In [1]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader
import csv
import random
import numpy as np
import cv2
import mimetypes
import fnmatch
import seaborn as sns
from collections import Counter, defaultdict
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import time
from progress.bar import IncrementalBar
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_auc_score
import io
from pytorch_lightning.callbacks import Callback
from datetime import datetime, date, time
from PIL import Image
import itertools 
from sklearn.model_selection import train_test_split
import torchvision.models as models
from ViT.models.modeling import VisionTransformer, CONFIGS
from urllib.request import urlretrieve

import sys

sys.path.insert(0, '/home/anna/Desktop/Diploma/Learning/Sources/')


from callbacks_2classes_x10 import plot_confusion_matrix
from torch.nn import functional as F
from callbacks_2classes_x10 import get_true_classes
from callbacks_2classes_x10 import get_predicted_classes
from callbacks_2classes_x10 import get_classes_probs
from callbacks_2classes_x10 import callback
from data_tools import CatsDataset

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

labels_map_2cl = {
    "NotCat": 0,
    "Cat": 1,
}

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = CatsDataset('train_paths.txt', transform = transform)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = CatsDataset('val_paths.txt', transform = transform)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)

Using cuda device


In [3]:
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

In [4]:
# Prepare Model
config = CONFIGS['ViT-B_16']
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
model.eval()
None

In [5]:
model.head = nn.Linear(768, 1)

In [6]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)

model.to(device)
None

  torch.nn.init.xavier_uniform(m.weight)


In [7]:
logdir = "../../Logits/ViT_B_16_cats_logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

In [8]:
writer = SummaryWriter(logdir)
vit_callback = callback()

In [9]:
# model.load_state_dict(torch.load("../../Logits/SavedNN/Saved_ViT_L_16_pets/" + str(6)))
# model.to(device)
# None

In [10]:
pos_weight = torch.from_numpy(np.array([0.5])).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))

In [1]:
epochs_num = 15

for epoch in tqdm_notebook(range(epochs_num), desc='epochs'):  # loop over the dataset multiple times
    
    vit_callback.on_epoch_begin(epoch) 
    
    running_loss = 0.0
    
    classes = []
    true_classes= []
    
#     if epoch == 6:
#         optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))

    for data in tqdm_notebook(train_dataloader, desc='one epoch training'):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        true_classes.append(labels.cpu().detach().numpy().astype(int))
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(inputs)[0]
        if outputs.shape > torch.Size([1]):
            outputs = outputs.squeeze()
        if outputs.shape < torch.Size([1]):
                outputs = outputs.unsqueeze(0)
        
        # СИГМОИДА И ГРАНИЦА
        probs = torch.sigmoid(outputs)
        probs = probs.cpu().detach().numpy().astype(float)
        classes.append(probs)
        
        # print(outputs.shape(), labels.shape())
        outputs = outputs.to(device)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    running_loss /= len(true_classes)
    
    val_classes = []
    val_true_classes = []
    val_loss = 0.0
    
    for data in tqdm_notebook(val_dataloader, desc='validation'):
        # get the inputs; data is a list of [inputs, labels]
        val_inputs, val_labels = data
        val_true_classes.append(val_labels.cpu().detach().numpy().astype(int))
        val_inputs = val_inputs.to(device)
        val_labels = val_labels.to(device)
    
        # forward 
        with torch.no_grad():
            val_outputs = model(val_inputs)[0]
            if val_outputs.shape > torch.Size([1]):
                val_outputs = val_outputs.squeeze()
            if val_outputs.shape < torch.Size([1]):
                val_outputs = val_outputs.unsqueeze(0)
            loss = criterion(val_outputs, val_labels.float())
            val_loss += loss.item()
            
        val_probs = torch.sigmoid(val_outputs)
        val_probs = val_probs.cpu().detach().numpy().astype(float)
        val_classes.append(val_probs) 
        
    val_loss /= len(val_true_classes)

    vit_callback.on_epoch_end(true_classes, classes, val_true_classes, val_classes,
                          ["NotCat", "Cat"],
                          running_loss, val_loss, writer)
    
    torch.save(model.state_dict(), "../../Logits/SavedNN/Saved_ViT_B_16_cats/" + str(epoch))

print('Finished Training')

NameError: name 'tqdm_notebook' is not defined

## Best Result - 2 or 4 epoch - ViT_B_32

## Best Result - 6 epoch - ViT_B_16