In [1]:
from IPython.display import clear_output

import copy
from clearml import Task, Logger, TaskTypes

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split

from torchmetrics import Precision, Recall, F1Score

from torchvision.datasets import CIFAR10
from torchvision.models import mobilenet_v3_large
import torchvision.transforms as transforms

from tqdm import tqdm

import matplotlib.pyplot as plt

import numpy as np

In [2]:
# Central experiment configuration read by the cells below.
parameters = {
    # Not referenced by any other cell visible in this notebook.
    'model_name': 'baseline.pth',
    # Mini-batch size passed to every DataLoader in the notebook.
    'batch_size': 1000,
    # Output size of the classification head and of the supervised metrics.
    'num_classes': 10,
    # Supervised epochs per experiment; the SSL pretraining loop runs twice this many.
    'epochs': 25,
    # Value passed to torch.manual_seed.
    'seed': 47,
    # Not referenced by any other cell visible in this notebook.
    'string': 'my string',
    # Keyword arguments unpacked into the training augmentations.
    'aug_params': {
        # -> transforms.ColorJitter(**...)
        'color_jitter': {'brightness': 0.5, 'contrast': 0.5, 'saturation': 0.5, 'hue': (-0.1, 0.1)}, 
        # -> transforms.RandomHorizontalFlip(**...)
        'random_flip': {'p': 0.5}
    }, 
    # Keyword arguments unpacked into torch.optim.AdamW for both SSL and supervised training.
    'optimizer_params': {
        'lr': 1e-3
    }
}

In [3]:
# Seed every RNG the notebook draws from. torch drives the augmentations,
# weight initialisation and random_split; NumPy drives the patch targets
# sampled in create_patch_pairs, so it must be seeded too for reproducible runs.
np.random.seed(parameters['seed'])
torch.manual_seed(parameters['seed'])

<torch._C.Generator at 0x221f028cb70>

In [4]:
# Training-time augmentation: colour jitter + horizontal flip, parameters taken from the config dict.
train_transforms = transforms.Compose([
    transforms.ColorJitter(**parameters['aug_params']['color_jitter']),
    transforms.RandomHorizontalFlip(**parameters['aug_params']['random_flip']),
    transforms.ToTensor(),
])

# Evaluation-time pipeline: no augmentation, only tensor conversion.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
])

# SSL pretext pipeline: resize to 33x33 so the image splits into a 3x3 grid of 11x11 patches.
patch_transforms = transforms.Compose([
    transforms.Resize((33, 33)),
    transforms.ToTensor(),
])

patch_train_set = CIFAR10("./data", download=True, transform=patch_transforms, train=True)
patch_val_set = CIFAR10("./data", download=True, transform=patch_transforms, train=False)
train_set = CIFAR10("./data", download=True, transform=train_transforms, train=True)
# Validation must see un-augmented images; using train_transforms here would make the
# validation metrics noisy and not comparable between epochs.
val_set = CIFAR10("./data", download=True, transform=test_transforms, train=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Class names as exposed by the dataset object; used for plot titles.
classes = train_set.classes
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def show_sample(images, labels, predicted_labels=None):
    """Plot up to 25 images in a 5x5 grid, titled with their class names.

    Args:
        images: indexable collection of image tensors; each one is transposed
            with axes (1, 2, 0) before display, i.e. it is expected to be
            channel-first (C, H, W).
        labels: indexable collection of class indices into the global ``classes``.
        predicted_labels: optional collection of predicted class indices; when
            given, titles read ``"<true> / <predicted>"``.
    """
    # The grid has 25 cells; smaller batches only fill the cells they have
    # instead of raising IndexError.
    n_shown = min(25, len(images))
    plt.figure(figsize=(20, 10), facecolor='white')
    for i in range(n_shown):
        plt.subplot(5, 5, i + 1)
        # detach()/cpu() make this work for tensors that live on the GPU or track gradients.
        image = np.transpose(images[i].detach().cpu().numpy(), (1, 2, 0))
        plt.imshow(image)
        if predicted_labels is not None:
            plt.title(f'{classes[labels[i]]} / {classes[predicted_labels[i]]}')
        else:
            plt.title(classes[labels[i]])
        plt.axis('off')
    plt.show()

In [7]:
def log_metrics(series, iteration, precision_metric, recall_metric, f1score_metric, loss):
    """Send one epoch's metrics to the global ClearML logger ``log``.

    For precision, recall and F1 the per-class tensor is reported as a
    histogram and its mean as a scalar; the loss is reported as a scalar.

    Args:
        series: series name for every report (the callers pass "Train" or "Validation").
        iteration: iteration index for the reports (the callers pass the epoch).
        precision_metric, recall_metric, f1score_metric: tensors of metric values;
            they are moved to the CPU before being reported.
        loss: value reported under the "CrossEntropyLoss" title.
    """
    # (histogram title, scalar title, values) for each metric, in reporting order.
    reported = (
        ("PrecisionByClass", "Precision", precision_metric),
        ("RecallByClass", "Recall", recall_metric),
        ("F1ScoreByClass", "F1Score", f1score_metric),
    )

    for histogram_title, _, per_class in reported:
        log.report_histogram(histogram_title, series, iteration=iteration, values=per_class.cpu())

    for _, scalar_title, per_class in reported:
        log.report_scalar(scalar_title, series, iteration=iteration, value=torch.mean(per_class).cpu())

    log.report_scalar("CrossEntropyLoss", series, iteration=iteration, value=loss)

### Self-supervised learning backbone pretraining

In [8]:
def create_patch_pairs(inputs):
    """Build (centre patch, neighbour patch, relative position) triples.

    Each image is treated as a 3x3 grid of 11x11 cells covering rows/columns
    0..32. The centre cell is paired with one of its eight neighbours chosen
    with ``np.random.randint``; the index of that neighbour is the target.

    Args:
        inputs: 4-D tensor indexed as (sample, channel, row, column) whose
            last two dimensions cover at least 33 pixels.

    Returns:
        Tuple ``(centre, neighbour, targets)``: both patch batches are
        bilinearly resized to 32x32, and ``targets`` is a 1-D tensor with one
        value in [0, 8) per sample.
    """
    # Offsets of the eight neighbours in row-major order, centre skipped.
    offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if not (dx == 0 and dy == 0)]

    centre = inputs[:, :, 11:22, 11:22]

    neighbours, labels = [], []
    for sample in inputs:
        label = np.random.randint(0, 8)
        dx, dy = offsets[label]
        # Top-left corner of the chosen cell: offsets -1/0/1 map to 0/11/22.
        first, second = 11 * (dx + 1), 11 * (dy + 1)
        neighbours.append(sample[:, first:first + 11, second:second + 11])
        labels.append(label)

    neighbour = torch.stack(neighbours)

    resize = dict(size=(32, 32), mode='bilinear', align_corners=False)
    centre = F.interpolate(centre, **resize)
    neighbour = F.interpolate(neighbour, **resize)

    return centre, neighbour, torch.tensor(labels)

In [9]:
class SSLModel(nn.Module):
    """Two-input network that classifies a patch pair into one of 8 classes.

    Both inputs go through the same ``backbone``; the two flattened feature
    vectors are concatenated and passed to a linear layer. The linear layer
    takes 2 * 960 inputs, so the backbone must yield 960 values per sample
    once flattened.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(2 * 960, 8)

    def _embed(self, patch):
        """Run one patch batch through the backbone and flatten per sample."""
        features = self.backbone(patch)
        return features.view(features.size(0), -1)

    def forward(self, x1, x2):
        """Return the 8-way logits for the pair ``(x1, x2)``."""
        pair_features = torch.cat((self._embed(x1), self._embed(x2)), dim=1)
        return self.classifier(pair_features)

In [10]:
# Per-class (average=None) metrics for the 8-way patch-position task, kept on the training device.
recall, precision, f1score = (
    metric_cls(task='multiclass', average=None, num_classes=8).to(device)
    for metric_cls in (Recall, Precision, F1Score)
)

In [11]:
# Open the ClearML task that tracks the SSL pretraining run; log_metrics reports through `log`.
task = Task.init(
    task_type=TaskTypes.training,
    project_name='Processing and generating images course',
    task_name='HW3 SSL model training',
)
log = Logger.current_logger()

ClearML Task: created new task id=7f042f2d22e448ea8d4b4d44c1e05deb
2024-03-13 16:49:18,730 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/7f042f2d22e448ea8d4b4d44c1e05deb/output/log


In [12]:
# Shuffle the training data so each epoch sees differently composed batches
# (DataLoader defaults to shuffle=False); validation order does not matter.
patch_train_loader = DataLoader(patch_train_set, batch_size=parameters['batch_size'], shuffle=True)
patch_val_loader = DataLoader(patch_val_set, batch_size=parameters['batch_size'])

In [13]:
# Randomly initialised MobileNetV3-Large feature extractor wrapped in the pair classifier.
backbone = mobilenet_v3_large(weights=None).features
ssl_model = SSLModel(backbone)
ssl_model.to(device)

# Loss and optimiser for the pretext task; optimiser settings come from the config dict.
optimizer_kwargs = parameters['optimizer_params']
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(ssl_model.parameters(), **optimizer_kwargs)

In [14]:
# Self-supervised pretraining: predict which of the 8 neighbouring cells the second
# patch was cut from. Runs for twice the supervised epoch budget.
for epoch in tqdm(range(2 * parameters['epochs'])):
    # ---- training pass ----
    ssl_model.train()
    train_epoch_loss = 0.0
    processed_data = 0
    for inputs, _ in patch_train_loader:
        inputs1, inputs2, targets = create_patch_pairs(inputs)
        inputs1, inputs2, targets = inputs1.to(device), inputs2.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = ssl_model(inputs1, inputs2)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

        # criterion returns the batch mean; weight it by the batch size so that
        # dividing by the sample count below yields the true per-sample mean.
        n_samples = inputs.size(0)
        train_epoch_loss += loss.item() * n_samples
        processed_data += n_samples

        precision(outputs, targets)
        recall(outputs, targets)
        f1score(outputs, targets)

    precision_metric = precision.compute()
    recall_metric = recall.compute()
    f1score_metric = f1score.compute()
    # Reset so the validation pass starts from empty metric state.
    precision.reset()
    recall.reset()
    f1score.reset()
    log_metrics("Train", epoch, precision_metric, recall_metric, f1score_metric, train_epoch_loss / processed_data)

    # ---- validation pass ----
    ssl_model.eval()
    val_epoch_loss = 0.0
    processed_data = 0
    with torch.no_grad():
        for inputs, _ in patch_val_loader:
            inputs1, inputs2, targets = create_patch_pairs(inputs)
            inputs1, inputs2, targets = inputs1.to(device), inputs2.to(device), targets.to(device)

            outputs = ssl_model(inputs1, inputs2)
            loss = criterion(outputs, targets)
            torch.cuda.empty_cache()

            n_samples = inputs.size(0)
            val_epoch_loss += loss.item() * n_samples
            processed_data += n_samples

            precision(outputs, targets)
            recall(outputs, targets)
            f1score(outputs, targets)

    precision_metric = precision.compute()
    recall_metric = recall.compute()
    f1score_metric = f1score.compute()
    precision.reset()
    recall.reset()
    f1score.reset()
    # Report the validation loss (not the training loss) for the "Validation" series.
    log_metrics("Validation", epoch, precision_metric, recall_metric, f1score_metric, val_epoch_loss / processed_data)

100%|██████████| 50/50 [27:04<00:00, 32.50s/it]


In [15]:
# Persist only the backbone weights; the 8-way pair classifier is not saved.
backbone_state = ssl_model.backbone.state_dict()
torch.save(backbone_state, 'pretrainder_backbone.pth')

2024-03-13 17:16:35,900 - clearml.frameworks - INFO - Found existing registered model id=71e36632d9e64a648b5b6f30af434537 [C:\Users\Anton Volodin\PycharmProjects\processing_and_generating_images_course\pretrainder_backbone.pth] reusing it.


In [16]:
# Close the ClearML task opened for SSL pretraining; each later experiment opens its own task.
task.close()

### Training utilities: model factory and training loop

In [17]:
def get_model(backbone_weights=None):
    """Build a MobileNetV3-Large classifier for ``parameters['num_classes']`` classes.

    The network starts from random weights (``weights=None``); its last
    classifier layer is replaced by a linear layer with the configured number
    of outputs.

    Args:
        backbone_weights: optional path to a state dict for ``model.features``
            (as saved from the SSL backbone). Falsy values skip loading.

    Returns:
        The model, moved to the global ``device``.
    """
    model = mobilenet_v3_large(weights=None)
    in_features = model.classifier[3].in_features
    model.classifier[3] = nn.Linear(in_features, parameters['num_classes'])
    if backbone_weights:
        # map_location lets a checkpoint saved from a GPU run load on a CPU-only
        # machine (and vice versa) instead of failing while deserialising.
        state_dict = torch.load(backbone_weights, map_location=device)
        model.features.load_state_dict(state_dict)
    model.to(device)
    return model

In [18]:
def train(model, train_loader, val_loader):
    """Train ``model`` for ``parameters['epochs']`` epochs and log metrics each epoch.

    Uses cross-entropy loss and AdamW configured from
    ``parameters['optimizer_params']``. After every training and validation
    pass, the global ``precision``/``recall``/``f1score`` metric objects are
    computed, reset and reported through ``log_metrics`` together with the
    mean per-sample loss.

    Args:
        model: network to optimise; expected to already live on the global ``device``.
        train_loader: iterable of ``(images, targets)`` training batches.
        val_loader: iterable of ``(images, targets)`` validation batches.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), **parameters['optimizer_params'])
    for epoch in range(parameters['epochs']):
        # ---- training pass ----
        model.train()
        train_epoch_loss = 0.0
        train_samples = 0
        for images, targets in train_loader:
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight by the batch size and count
            # the samples actually seen, so the logged loss is a per-sample mean
            # even when the loader wraps only a subset of the data.
            n_samples = images.size(0)
            train_epoch_loss += loss.item() * n_samples
            train_samples += n_samples

            precision(outputs, targets)
            recall(outputs, targets)
            f1score(outputs, targets)

            torch.cuda.empty_cache()
        precision_metric = precision.compute()
        recall_metric = recall.compute()
        f1score_metric = f1score.compute()
        # Reset so the validation pass starts from empty metric state.
        precision.reset()
        recall.reset()
        f1score.reset()
        log_metrics("Train", epoch, precision_metric, recall_metric, f1score_metric,
                    train_epoch_loss / max(train_samples, 1))

        # ---- validation pass ----
        model.eval()
        val_epoch_loss = 0.0
        val_samples = 0
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                targets = targets.to(device)

                outputs = model(images)
                loss = criterion(outputs, targets)

                n_samples = images.size(0)
                val_epoch_loss += loss.item() * n_samples
                val_samples += n_samples

                precision(outputs, targets)
                recall(outputs, targets)
                f1score(outputs, targets)

                torch.cuda.empty_cache()
        precision_metric = precision.compute()
        recall_metric = recall.compute()
        f1score_metric = f1score.compute()
        precision.reset()
        recall.reset()
        f1score.reset()
        log_metrics("Validation", epoch, precision_metric, recall_metric, f1score_metric,
                    val_epoch_loss / max(val_samples, 1))

In [19]:
# Rebind the global metric objects for the supervised task: per-class values
# (average=None) over parameters['num_classes'] classes, kept on the training device.
recall, precision, f1score = (
    metric_cls(task='multiclass', average=None, num_classes=parameters['num_classes']).to(device)
    for metric_cls in (Recall, Precision, F1Score)
)

### Exp 1: 100% of data

In [20]:
# Exp 1: fine-tune the SSL-pretrained backbone on the full training set.
task = Task.init(
    project_name='Processing and generating images course',
    task_name='HW3 Exp 1 pretrained',
    task_type=TaskTypes.training)
log = Logger.current_logger()

# Shuffle the training data so each epoch sees differently composed batches
# (DataLoader defaults to shuffle=False); validation order does not matter.
train_loader = DataLoader(train_set, batch_size=parameters['batch_size'], shuffle=True)
val_loader = DataLoader(val_set, batch_size=parameters['batch_size'])

model = get_model(backbone_weights='pretrainder_backbone.pth')
train(model, train_loader, val_loader)
task.close()

ClearML Task: created new task id=61b167925a0f4d959b1c5d28566b632d
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/61b167925a0f4d959b1c5d28566b632d/output/log


### Exp 2: 50% of data

In [21]:
# Keep a random half of the training set. The second length is computed as the
# remainder so the two lengths always sum to len(train_set), which random_split
# requires (two `// 2` halves fall one short for an odd-sized dataset).
half_size = len(train_set) // 2
train_set_half, _ = random_split(train_set, [half_size, len(train_set) - half_size])
# Shuffle so each epoch sees differently composed batches.
train_loader = DataLoader(train_set_half, batch_size=parameters['batch_size'], shuffle=True)

In [22]:
# Exp 2 baseline: train from random initialisation on whatever `train_loader` currently holds.
task = Task.init(
    task_type=TaskTypes.training,
    project_name='Processing and generating images course',
    task_name='HW3 Exp 2 not pretrained',
)
log = Logger.current_logger()

model = get_model(backbone_weights=None)
train(model, train_loader, val_loader)

task.close()

ClearML Task: created new task id=84698a15ca0d48398bc157c09588fa25
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/84698a15ca0d48398bc157c09588fa25/output/log


In [23]:

# Exp 2 with the SSL backbone: same data as the baseline, backbone weights loaded from disk.
task = Task.init(
    task_type=TaskTypes.training,
    project_name='Processing and generating images course',
    task_name='HW3 Exp 2 pretrained',
)
log = Logger.current_logger()

model = get_model(backbone_weights='pretrainder_backbone.pth')
train(model, train_loader, val_loader)

task.close()

ClearML Task: created new task id=4e0141caee3e4810b149d4386138ab2d
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/4e0141caee3e4810b149d4386138ab2d/output/log


### Exp 3: 10% of data

In [24]:
# Keep a random 10% of the training set; the remainder is discarded.
tenth_size = len(train_set) // 10
train_set_10, _ = random_split(train_set, [tenth_size, len(train_set) - tenth_size])
# Shuffle so each epoch sees differently composed batches
# (DataLoader defaults to shuffle=False).
train_loader = DataLoader(train_set_10, batch_size=parameters['batch_size'], shuffle=True)

In [25]:
# Exp 3 baseline: train from random initialisation on whatever `train_loader` currently holds.
task = Task.init(
    task_type=TaskTypes.training,
    project_name='Processing and generating images course',
    task_name='HW3 Exp 3 not pretrained',
)
log = Logger.current_logger()

model = get_model(backbone_weights=None)
train(model, train_loader, val_loader)

task.close()

ClearML Task: created new task id=7e51ea5e9d8447ada915f5c402a363d3
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/7e51ea5e9d8447ada915f5c402a363d3/output/log


In [26]:
# Exp 3 with the SSL backbone: same data as the baseline, backbone weights loaded from disk.
task = Task.init(
    task_type=TaskTypes.training,
    project_name='Processing and generating images course',
    task_name='HW3 Exp 3 pretrained',
)
log = Logger.current_logger()

model = get_model(backbone_weights='pretrainder_backbone.pth')
train(model, train_loader, val_loader)

task.close()

ClearML Task: created new task id=e5fdb4e6ae1c40b1b275295e645b7c8b
ClearML results page: https://app.clear.ml/projects/c907675c01ad4f69a5f853a34e753129/experiments/e5fdb4e6ae1c40b1b275295e645b7c8b/output/log
