# Get data

In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# grayscale_transform = transforms.RandomGrayscale(0.1)

fext_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


pdn_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_path = '/kaggle/input/imagenetmini-1000/imagenet-mini/train'

fext_dataset = ImageFolder(root=train_path, transform=fext_transform)
pdn_dataset = ImageFolder(root=train_path, transform=pdn_transform)

dataloader = DataLoader(fext_dataset, batch_size=50, shuffle=True)

# Feature extractor (pretrained wide resnet 101

In [3]:
class FeatureExtractor(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        backbone = models.wide_resnet101_2(weights='DEFAULT')
        
        
        self.initial = nn.Sequential(*list(backbone.children())[:5])
        self.layr2 = backbone.layer2
        self.layr3 = backbone.layer3
        
        self.unfold = nn.Unfold(kernel_size=3, stride=1, padding=1, dilation=1)
        
        
    def forward(self, x):
        
        x = self.initial(x)      
        x1 = self.layr2(x)
        x2 = self.layr3(x1)
        
        x2_upsampled = F.interpolate(x2, size=(64, 64), mode='bilinear', align_corners=False)
        
        
        concatenated_tensor = torch.cat((x1, x2_upsampled), dim=1)
        
        desired_channels = 384
        aggregated_tensor = F.interpolate(concatenated_tensor.unsqueeze(0), size=(desired_channels, 64, 64), mode='trilinear').squeeze(0)
        
        
        return aggregated_tensor
    
    
fext = FeatureExtractor()

for name, param in fext.named_parameters():
    param.requires_grad = False

if torch.cuda.device_count() > 1:
    fext = nn.DataParallel(fext)
    
fext = fext.to(device)

Downloading: "https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet101_2-d733dc28.pth
100%|██████████| 485M/485M [00:03<00:00, 151MB/s]  


# Teacher network (PDN)

In [4]:
class TeacherNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4, stride=1, padding=3),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
            
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=1, padding=3),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
            
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=4, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.features(x)
        return x

pdn = TeacherNetwork()

if torch.cuda.device_count() > 1:
    pdn = nn.DataParallel(pdn)

pdn = pdn.to(device)

In [5]:
mu_channel = []
sig_channel = []

fext.eval()
for i, (images, _) in enumerate(dataloader):

    if i>15:
        break

    with torch.no_grad():
        fext_out = fext(images)   

    if i == 0:
        sequence = fext_out

    else:
        sequence = torch.cat((sequence, fext_out), dim=0)


sequence = sequence.mean(dim=0)

for tensr in sequence:
    mu_channel.append(tensr.mean().item())
    sig_channel.append(tensr.std().item())
        
        
fext_normalize = transforms.Normalize(mean=mu_channel, std=sig_channel)

In [6]:
import gc
gc.collect()

39

In [7]:
criterion = nn.MSELoss()

optimizer = torch.optim.Adam(pdn.parameters(), lr=1e-4, weight_decay=1e-5)

In [8]:
random.seed(9)

fext.eval()
pdn.train()

for epoch in range(5000):
    
    for bix in range(16):
        
        img_id = random.randint(0, len(fext_dataset)-1)
        fext_image = fext_dataset[img_id][0]
        pdn_image = pdn_dataset[img_id][0]
#         image = grayscale_transform(image)
        
        fext_img = fext_image.unsqueeze(0).to(device)
        pdn_img = pdn_image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            fext_out = fext(fext_img)
            fext_out = fext_normalize(fext_out)
            
        pdn_out = pdn(pdn_img)
        
        loss = criterion(pdn_out, fext_out)
        
#         if bix==0:
#             batch_loss = torch.tensor([loss.item()], requires_grad=True, device=device)
#         else:
#             batch_loss = torch.cat((batch_loss, torch.tensor([loss.item()], requires_grad=True, device=device)), dim=0)

        if bix==0:
            batch_loss = loss
        else:
            batch_loss = batch_loss + loss
    

    avg_loss = batch_loss/16

    
    if epoch==0 or (epoch+1)%200==0:
        print(f"Loss: {avg_loss}")
    
    optimizer.zero_grad()
    
    avg_loss.backward()
    
    optimizer.step()

Loss: 80.39542388916016
Loss: 81.05866241455078
Loss: 84.54275512695312
Loss: 74.74446105957031
Loss: 76.28498840332031
Loss: 76.0407485961914
Loss: 72.04757690429688
Loss: 73.52398681640625
Loss: 72.39143371582031
Loss: 79.05269622802734
Loss: 76.18132019042969
Loss: 74.90377044677734
Loss: 78.13765716552734
Loss: 68.6776351928711
Loss: 73.0262680053711
Loss: 73.55596923828125
Loss: 72.34915924072266
Loss: 73.3984375
Loss: 68.65929412841797
Loss: 77.50553131103516
Loss: 73.97562408447266
Loss: 75.40345001220703
Loss: 74.63668823242188
Loss: 73.56914520263672
Loss: 74.30719757080078
Loss: 76.3548583984375


In [9]:
torch.save(pdn.module.state_dict(), 'Teacher_model5000.pth')