# Individual Classifier for Datasets: FGVC & Naturalist

*Code Writer: Chaeeun Ryu*

In [1]:
!nvidia-smi

Sat Mar 18 08:58:42 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:01:00.0 Off |                  Off |
| 30%   29C    P8    26W / 300W |   1315MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:25:00.0 Off |                  Off |
| 30%   28C    P8    31W / 300W |    396MiB / 49140MiB |      0%      Default |
|       

# Import Libraries

In [2]:
import argparse
import os
import random
import time
import warnings
import sys
import numpy as np
import torch
import torch.nn as nn 
import torchvision.transforms as transforms
from collections import Counter, OrderedDict
import torch.optim
from sklearn.metrics import confusion_matrix
from torchvision.datasets import OxfordIIITPet
import os
from torchvision.datasets import FGVCAircraft
from torchvision.datasets import INaturalist
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
import torchvision.models as models
import glob
import copy
# from torchvision.models import resnet50, ResNet50_Weights
import torchvision
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
device =  torch.device('cuda:4')
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [3]:
device

device(type='cuda', index=4)

In [4]:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = True  # type: ignore

# Config

In [5]:
from dataclasses import dataclass
global args
@dataclass
class ARGS():
    LR = 0.0005
    EPOCHS = 200
    BATCHSIZE = 500
    MOMENTUM = 0.9
    WORKERS = 0
    WEIGHTDECAY = 0
    T_MAX = 150
    CLS_CLASS = 37
    SEG_CLASS = 3
    SIZE = 224
    SEED = 38
    MAX_PATIENCE = 15
    
args = ARGS()

In [6]:
seed_everything(args.SEED)

## Global Transformations

In [7]:
transform_train = transforms.Compose([
        transforms.Resize((args.SIZE,args.SIZE)),
        transforms.RandomCrop(args.SIZE, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

transform_val = transforms.Compose([
        transforms.Resize((args.SIZE,args.SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

# FGVCAircraft

In [8]:
classification_train_dataset_variant = FGVCAircraft(root='./data', split="trainval", annotation_level  = 'variant',download=True, transform=transform_train)
classification_val_dataset_variant = FGVCAircraft(root='./data', split="test", annotation_level  = 'variant',download=True, transform=transform_val)

classification_train_dataset_family = FGVCAircraft(root='./data', split="trainval", annotation_level  = 'family',download=True, transform=transform_train)
classification_val_dataset_family = FGVCAircraft(root='./data', split="test", annotation_level  = 'family',download=True, transform=transform_val)

classification_train_dataset_manufacturer = FGVCAircraft(root='./data', split="trainval", annotation_level  = 'manufacturer',download=True, transform=transform_train)
classification_val_dataset_manufacturer = FGVCAircraft(root='./data', split="test", annotation_level  = 'manufacturer',download=True, transform=transform_val)

In [9]:
len(classification_train_dataset_variant.classes)

100

In [10]:
global data_classes
data_classes = dict()
data_classes['variant'] = len(classification_train_dataset_variant.classes)
data_classes['family'] = len(classification_train_dataset_family.classes)
data_classes['manufacturer'] = len(classification_train_dataset_manufacturer.classes)

In [11]:
data_classes

{'variant': 100, 'family': 70, 'manufacturer': 30}

In [12]:
global data_loaders
data_loaders = dict()
d_name = "variant"
variant_train_loader = torch.utils.data.DataLoader(
    classification_train_dataset_variant, batch_size=args.BATCHSIZE, shuffle=True,
    num_workers=args.WORKERS, pin_memory=True)
variant_val_loader = torch.utils.data.DataLoader(
    classification_val_dataset_variant, batch_size=args.BATCHSIZE, shuffle=False,
    num_workers=args.WORKERS, pin_memory=True)
data_loaders[d_name] = {'train_loader':variant_train_loader,'val_loader':variant_val_loader}

d_name = "family"
family_train_loader = torch.utils.data.DataLoader(
    classification_train_dataset_family, batch_size=args.BATCHSIZE, shuffle=True,
    num_workers=args.WORKERS, pin_memory=True)
family_val_loader = torch.utils.data.DataLoader(
    classification_val_dataset_family, batch_size=args.BATCHSIZE, shuffle=False,
    num_workers=args.WORKERS, pin_memory=True)
data_loaders[d_name] = {'train_loader':family_train_loader,'val_loader':family_val_loader}

d_name = "manufacturer"
manufacturer_train_loader = torch.utils.data.DataLoader(
    classification_train_dataset_manufacturer, batch_size=args.BATCHSIZE, shuffle=True,
    num_workers=args.WORKERS, pin_memory=True)
manufacturer_val_loader = torch.utils.data.DataLoader(
    classification_val_dataset_manufacturer, batch_size=args.BATCHSIZE, shuffle=False,
    num_workers=args.WORKERS, pin_memory=True)
data_loaders[d_name] = {'train_loader':manufacturer_train_loader,'val_loader':manufacturer_val_loader}

In [13]:
data_loaders.keys()

dict_keys(['variant', 'family', 'manufacturer'])

# Classification

## FGVC 1. Variant

In [14]:
d_name = 'variant'
n_label = data_classes[d_name]
print(f"number of labels for data {d_name}: {n_label}")

number of labels for data variant: 100


In [15]:
'''ResNet in PyTorch.
Reference:
https://github.com/kuangliu/pytorch-cifar
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
        self.relu = nn.ReLU(inplace = True) 
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=args.CLS_CLASS):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
#         print("block.expansion:",block.expansion) == 1
#         self.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512*block.expansion, num_classes)
        
#         self.linear = nn.Linear(8192, num_classes)
        

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
#         out = F.avg_pool2d(out4, 4)
        out = self.avgpool(out4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out#, [out1, out2, out3, out4]


def ResNet50(num_classes):
    return ResNet(Bottleneck, [3,4,6,3],num_classes)

In [16]:
resnet50 = ResNet50(n_label)
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1)

In [17]:
cls_criterion = nn.CrossEntropyLoss().to(device)

In [18]:
def train_cls(epoch,train_loader,model,seg_criterion, optimizer, device=device, cls_criterion = cls_criterion):
    assert seg_criterion == None
    ss_losses = []
    save_ss_acc = []
    
    model.to(device)
    model.train()
    if epoch == 0:
        for image, label in tqdm(train_loader):
            image,label = image.to(device), label.to(device)
            surrogate_out = model(image)
            surrogate_pred = torch.argmax(surrogate_out,axis = 1)
            cls_loss = cls_criterion(surrogate_out, label)
            loss = cls_loss
            ss_losses.append(cls_loss.item())
            acc1 = (torch.sum((surrogate_pred==label)*1)/len(label)).item()
            save_ss_acc.append(acc1)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        return np.mean(ss_losses),np.mean(save_ss_acc)
    else:
        for image, label in train_loader:
            image,label = image.to(device), label.to(device)
            surrogate_out = model(image)
            surrogate_pred = torch.argmax(surrogate_out,axis = 1)
            cls_loss = cls_criterion(surrogate_out, label)
            loss = cls_loss
            ss_losses.append(cls_loss.item())
            acc1 = (torch.sum((surrogate_pred==label)*1)/len(label)).item()
            save_ss_acc.append(acc1)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        return np.mean(ss_losses),np.mean(save_ss_acc)

In [19]:
def eval_cls(val_loader, model, seg_criterion, device, cls_criterion = cls_criterion):
    assert seg_criterion == None
    ss_losses = [] 
    ss_acc = []
    model.eval()
    with torch.no_grad():
        for image,label in val_loader:
            image,label = image.to(device),label.to(device)
            surrogate_out = model(image)
            surrogate_pred = torch.argmax(surrogate_out, axis = 1)
            acc1 = (torch.sum((surrogate_pred==label)*1)/len(surrogate_pred)).item()
            ss_acc.append(acc1)
            
            cls_loss = cls_criterion(surrogate_out,label)
            loss = cls_loss
            ss_losses.append(cls_loss.item())
            
    return np.mean(ss_losses), np.mean(ss_acc)

In [20]:
ch_unet_optimizer = torch.optim.SGD(resnet50.parameters(), args.LR,
                            momentum=args.MOMENTUM,
                            weight_decay=args.WEIGHTDECAY)
ch_unet_scheduler = CosineAnnealingLR(ch_unet_optimizer, T_max = args.T_MAX)

In [None]:
from tqdm.notebook import tqdm

save_ss_train_loss = []
save_ss_val_loss = []
save_ss_train_acc = []
save_ss_val_acc = []

ss_best_val_acc = 0.

that_time_acc= None
max_patience = args.MAX_PATIENCE
resnet50.to(device)
seg_criterion = None

print("training starting..")

for epoch in tqdm(range(args.EPOCHS)):
    ss_train_loss,ss_train_acc = train_cls(epoch,data_loaders[d_name]['train_loader'], resnet50, seg_criterion, ch_unet_optimizer,device, cls_criterion)
    ss_val_loss,ss_val_acc = eval_cls(data_loaders[d_name]['val_loader'], resnet50, seg_criterion, device, cls_criterion)
    
    #self-supervision
    save_ss_train_loss.append(ss_train_loss)
    save_ss_train_acc.append(ss_train_acc)
    save_ss_val_loss.append(ss_val_loss)
    save_ss_val_acc.append(ss_val_acc)
    
    print(f"======= epoch {epoch} ========")
    print("avg train acc:",ss_train_acc,"avg val acc:",ss_val_acc, "ss train loss:",ss_train_loss, "ss val loss:",ss_val_loss)
    
    if epoch%20 == 0:
        print(f"best average acc so far:{ss_best_val_acc}")
    if ss_val_acc > ss_best_val_acc:
        print("model updated at epoch:",epoch,"avg val acc:",ss_val_acc)
        best_model = copy.deepcopy(resnet50)
        ss_best_val_acc = ss_val_acc
        patience = 0
    else:
        patience +=1
    
    if patience > max_patience:
        print("patience overloaded..! stop learning.........")
        break
        
    ch_unet_scheduler.step()

#지금까지 high score: 0.87

training starting..


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

avg train acc: 0.010569718167451876 avg val acc: 0.009428571377481734 ss train loss: 4.731159346444266 ss val loss: 4.6304304259163995
best average acc so far:0.0
model updated at epoch: 0 avg val acc: 0.009428571377481734
avg train acc: 0.0117125753313303 avg val acc: 0.009428571643573897 ss train loss: 4.618389742715018 ss val loss: 4.628241130283901
model updated at epoch: 1 avg val acc: 0.009428571643573897


In [None]:
max(save_ss_val_acc)

In [None]:
log_dict = dict()
# log_dict['save_train_loss'] = save_train_loss
# log_dict['save_val_loss'] = save_val_loss
# log_dict['save_train_dice'] = save_train_dice
# log_dict['save_val_dice'] = save_val_dice
log_dict['save_ss_train_acc'] = save_ss_train_acc
log_dict['save_ss_val_acc'] = save_ss_val_acc
log_dict['save_ss_train_loss'] = save_ss_train_loss
log_dict['save_ss_val_loss'] = save_ss_val_loss

In [None]:
torch.save(log_dict,f"./{d_name}_resnet50_only_cls_log.pt")
torch.save({"model_w":resnet50.state_dict(),"best_model_w":best_model.state_dict()},f"./{d_name}resnet50_only_cls_w.pt")

In [None]:
plt.title("Saved loss for only classification (ResNet50)")
plt.plot(save_ss_train_loss,label = "train loss")
plt.plot(save_ss_val_loss, label = "valid loss")
plt.legend()
plt.show()

In [None]:
plt.title("Saved acc for only classification (ResNet50)")
plt.plot(save_ss_train_acc,label = "train")
plt.plot(save_ss_val_acc,label = "valid")
plt.legend()
plt.show()

## FGVC 2. Family

In [None]:
d_name = 'family'
n_label = data_classes[d_name]
print(f"number of labels for data {d_name}: {n_label}")

In [None]:
resnet50 = ResNet50(n_label)
resnet50

In [None]:
ch_unet_optimizer = torch.optim.SGD(resnet50.parameters(), args.LR,
                            momentum=args.MOMENTUM,
                            weight_decay=args.WEIGHTDECAY)
ch_unet_scheduler = CosineAnnealingLR(ch_unet_optimizer, T_max = args.T_MAX)

In [None]:
from tqdm.notebook import tqdm

save_ss_train_loss = []
save_ss_val_loss = []
save_ss_train_acc = []
save_ss_val_acc = []

ss_best_val_acc = 0.

that_time_acc= None
max_patience = args.MAX_PATIENCE
resnet50.to(device)
seg_criterion = None

print("training starting..")

for epoch in tqdm(range(args.EPOCHS)):
    ss_train_loss,ss_train_acc = train_cls(epoch,data_loaders[d_name]['train_loader'], resnet50, seg_criterion, ch_unet_optimizer,device, cls_criterion)
    ss_val_loss,ss_val_acc = eval_cls(data_loaders[d_name]['val_loader'], resnet50, seg_criterion, device, cls_criterion)
    
    #self-supervision
    save_ss_train_loss.append(ss_train_loss)
    save_ss_train_acc.append(ss_train_acc)
    save_ss_val_loss.append(ss_val_loss)
    save_ss_val_acc.append(ss_val_acc)
    
    print(f"======= epoch {epoch} ========")
    print("avg train acc:",ss_train_acc,"avg val acc:",ss_val_acc, "ss train loss:",ss_train_loss, "ss val loss:",ss_val_loss)
    
    if epoch%20 == 0:
        print(f"best average acc so far:{ss_best_val_acc}")
    if ss_val_acc > ss_best_val_acc:
        print("model updated at epoch:",epoch,"avg val acc:",ss_val_acc)
        best_model = copy.deepcopy(resnet50)
        ss_best_val_acc = ss_val_acc
        patience = 0
    else:
        patience +=1
    
    if patience > max_patience:
        print("patience overloaded..! stop learning.........")
        break
        
    ch_unet_scheduler.step()

#지금까지 high score: 0.87

In [None]:
max(save_ss_val_acc)

In [None]:
log_dict = dict()
# log_dict['save_train_loss'] = save_train_loss
# log_dict['save_val_loss'] = save_val_loss
# log_dict['save_train_dice'] = save_train_dice
# log_dict['save_val_dice'] = save_val_dice
log_dict['save_ss_train_acc'] = save_ss_train_acc
log_dict['save_ss_val_acc'] = save_ss_val_acc
log_dict['save_ss_train_loss'] = save_ss_train_loss
log_dict['save_ss_val_loss'] = save_ss_val_loss

In [None]:
torch.save(log_dict,f"./{d_name}_resnet50_only_cls_log.pt")
torch.save({"model_w":resnet50.state_dict(),"best_model_w":best_model.state_dict()},f"./{d_name}resnet50_only_cls_w.pt")

In [None]:
plt.title(f"Saved loss for only classification (ResNet50) ({d_name})")
plt.plot(save_ss_train_loss,label = "train loss")
plt.plot(save_ss_val_loss, label = "valid loss")
plt.legend()
plt.show()

In [None]:
plt.title(f"Saved acc for only classification (ResNet50) ({d_name})")
plt.plot(save_ss_train_acc,label = "train")
plt.plot(save_ss_val_acc,label = "valid")
plt.legend()
plt.show()

## FGVC 3. manufacturer

In [None]:
d_name = 'manufacturer'
n_label = data_classes[d_name]
print(f"number of labels for data {d_name}: {n_label}")

In [None]:
resnet50 = ResNet50(n_label)
resnet50

In [None]:
ch_unet_optimizer = torch.optim.SGD(resnet50.parameters(), args.LR,
                            momentum=args.MOMENTUM,
                            weight_decay=args.WEIGHTDECAY)
ch_unet_scheduler = CosineAnnealingLR(ch_unet_optimizer, T_max = args.T_MAX)

In [None]:
from tqdm.notebook import tqdm

save_ss_train_loss = []
save_ss_val_loss = []
save_ss_train_acc = []
save_ss_val_acc = []

ss_best_val_acc = 0.

that_time_acc= None
max_patience = args.MAX_PATIENCE
resnet50.to(device)
seg_criterion = None

print("training starting..")

for epoch in tqdm(range(args.EPOCHS)):
    ss_train_loss,ss_train_acc = train_cls(epoch,data_loaders[d_name]['train_loader'], resnet50, seg_criterion, ch_unet_optimizer,device, cls_criterion)
    ss_val_loss,ss_val_acc = eval_cls(data_loaders[d_name]['val_loader'], resnet50, seg_criterion, device, cls_criterion)
    
    #self-supervision
    save_ss_train_loss.append(ss_train_loss)
    save_ss_train_acc.append(ss_train_acc)
    save_ss_val_loss.append(ss_val_loss)
    save_ss_val_acc.append(ss_val_acc)
    
    print(f"======= epoch {epoch} ========")
    print("avg train acc:",ss_train_acc,"avg val acc:",ss_val_acc, "ss train loss:",ss_train_loss, "ss val loss:",ss_val_loss)
    
    if epoch%20 == 0:
        print(f"best average acc so far:{ss_best_val_acc}")
    if ss_val_acc > ss_best_val_acc:
        print("model updated at epoch:",epoch,"avg val acc:",ss_val_acc)
        best_model = copy.deepcopy(resnet50)
        ss_best_val_acc = ss_val_acc
        patience = 0
    else:
        patience +=1
    
    if patience > max_patience:
        print("patience overloaded..! stop learning.........")
        break
        
    ch_unet_scheduler.step()

#지금까지 high score: 0.87

In [None]:
max(save_ss_val_acc)

In [None]:
log_dict = dict()
# log_dict['save_train_loss'] = save_train_loss
# log_dict['save_val_loss'] = save_val_loss
# log_dict['save_train_dice'] = save_train_dice
# log_dict['save_val_dice'] = save_val_dice
log_dict['save_ss_train_acc'] = save_ss_train_acc
log_dict['save_ss_val_acc'] = save_ss_val_acc
log_dict['save_ss_train_loss'] = save_ss_train_loss
log_dict['save_ss_val_loss'] = save_ss_val_loss

In [None]:
torch.save(log_dict,f"./{d_name}_resnet50_only_cls_log.pt")
torch.save({"model_w":resnet50.state_dict(),"best_model_w":best_model.state_dict()},f"./{d_name}resnet50_only_cls_w.pt")

In [None]:
plt.title(f"Saved loss for only classification (ResNet50) ({d_name})")
plt.plot(save_ss_train_loss,label = "train loss")
plt.plot(save_ss_val_loss, label = "valid loss")
plt.legend()
plt.show()

In [None]:
plt.title(f"Saved acc for only classification (ResNet50) ({d_name})")
plt.plot(save_ss_train_acc,label = "train")
plt.plot(save_ss_val_acc,label = "valid")
plt.legend()
plt.show()