In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

from utils import *
from model_training import DoodleDataset, RealDataset

In [3]:
class V2ConvNet(nn.Module):
    """Strided convolutional classifier with a two-layer linear head.

    The convolutional stack starts with two unpadded 3x3 convolutions at
    ``channel_list[0]`` width. Each subsequent entry of ``channel_list`` adds a
    stage made of two stride-2, padding-1 3x3 convolutions followed by
    ``BatchNorm2d``, ``Dropout`` and ``ReLU``. The resulting feature map is
    adaptively average-pooled to ``pool_option``, flattened, and passed through
    ``Linear -> Linear`` to produce ``num_classes`` logits.

    Args:
        in_c: Number of channels of the input image tensor.
        num_classes: Number of output logits.
        channel_list: Channel width of each stage; the first entry is used by
            the two stem convolutions, every later entry adds one stride-2
            stage. Must contain at least one entry.
        pool_option: Output size ``(H, W)`` of the adaptive average pool. A
            single int is treated as a square size.
        hidden: Width of the hidden linear layer.
        dropout: Drop probability used in every stage's ``nn.Dropout``.
        add_layers: Currently unused; kept so existing call sites keep working.
    """

    def __init__(self, in_c,
                 num_classes,
                 channel_list=(64, 128, 192, 256, 512),
                 pool_option=(1, 1),
                 hidden=256,
                 dropout=0.2,
                 add_layers=False):
        super().__init__()

        # Copy so that neither the (immutable) default nor a caller's list is shared.
        channel_list = list(channel_list)
        if isinstance(pool_option, int):
            pool_option = (pool_option, pool_option)

        # Stem: two unpadded 3x3 convolutions at the first channel width.
        layers = [
            nn.Conv2d(in_c, channel_list[0], kernel_size=3),
            nn.Conv2d(channel_list[0], channel_list[0], kernel_size=3),
        ]

        # One down-sampling stage per remaining channel width.
        for i in range(1, len(channel_list)):
            layers.extend([
                nn.Conv2d(channel_list[i - 1], channel_list[i], kernel_size=3, stride=2, padding=1, bias=True),
                nn.Conv2d(channel_list[i], channel_list[i], kernel_size=3, stride=2, padding=1, bias=True),
                nn.BatchNorm2d(channel_list[i]),
                nn.Dropout(dropout),
                nn.ReLU(),
            ])

        self.conv = nn.Sequential(*layers)

        # Despite the attribute name this is the spatial pooling layer; the
        # name is kept so existing checkpoints / attribute accesses still work.
        self.flatten = nn.AdaptiveAvgPool2d(pool_option)

        self.fc = nn.Sequential(
            nn.Linear(pool_option[0] * pool_option[1] * channel_list[-1], hidden),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x, return_feats=False):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the convolutional stack expects a 4-D tensor with
                ``in_c`` channels.
            return_feats: When true, also return the un-pooled output of the
                convolutional stack.

        Returns:
            The logits tensor, or ``(logits, feats)`` if ``return_feats`` is set.
        """
        feats = self.conv(x)
        # Pool the convolutional features (not the raw input) and flatten them
        # to (N, C * H * W), which is the input width of the first linear layer.
        pooled = self.flatten(feats).flatten(1)
        logits = self.fc(pooled)

        if return_feats:
            return logits, feats

        return logits

In [4]:
# Smoke test: push one random batch through a freshly built network, print the
# shape of the logits, then print the layer/parameter summary.
# (`summary` is already imported in the notebook's import cell.)
x = torch.rand(100, 3, 64, 64)
net = V2ConvNet(3, 9, [64, 128, 192, 256, 512])
y = net(x)
print(y.shape)
print(summary(net))

torch.Size([100, 9])
Layer (type:depth-idx)                   Param #
V2ConvNet                                --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─Conv2d: 2-2                       36,928
│    └─Conv2d: 2-3                       73,856
│    └─Conv2d: 2-4                       147,584
│    └─BatchNorm2d: 2-5                  256
│    └─Dropout: 2-6                      --
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       221,376
│    └─Conv2d: 2-9                       331,968
│    └─BatchNorm2d: 2-10                 384
│    └─Dropout: 2-11                     --
│    └─ReLU: 2-12                        --
│    └─Conv2d: 2-13                      442,624
│    └─Conv2d: 2-14                      590,080
│    └─BatchNorm2d: 2-15                 512
│    └─Dropout: 2-16                     --
│    └─ReLU: 2-17                        --
│    └─Conv2d: 2-18                      1,180,160
│   

In [5]:
# Datasets selected through RealDataset's `train` flag
# (presumably the train / validation splits — confirm in model_training).
real_train_set = RealDataset(train=True)
real_val_set = RealDataset(train=False)

# Only the training loader shuffles; both use batches of 512.
real_train_loader = torch.utils.data.DataLoader(
    dataset=real_train_set,
    batch_size=512,
    shuffle=True,
)
real_val_loader = torch.utils.data.DataLoader(
    dataset=real_val_set,
    batch_size=512,
)

In [7]:
# Loss function, network and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
model = V2ConvNet(
    in_c=3,
    num_classes=9,
    channel_list=[64, 128, 192, 512],
)
# AdamW with its library-default hyper-parameters.
optim = torch.optim.AdamW(params=model.parameters())

In [11]:
fix_seed(0)  # zero seed by default

# Restrict this process to physical GPUs 2 and 3.
# NOTE(review): this variable is read when CUDA initializes, so it presumably
# has no effect if an earlier cell already touched the GPU in this kernel —
# confirm by running after a kernel restart.
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"

# Wrap only once so that re-running this cell does not nest DataParallel wrappers.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.cuda()

In [12]:
def get_accuracy(pred, label):
    """Return the fraction of rows of ``pred`` whose arg-max equals ``label``.

    Both tensors are copied to the CPU before comparison. ``pred`` is reduced
    with ``argmax`` over dimension 1; the result is a Python float obtained by
    dividing the number of matches by ``len(label)``.
    """
    predicted_classes = pred.cpu().argmax(1)
    targets = label.cpu()
    num_correct = predicted_classes.eq(targets).sum().item()
    return num_correct / len(targets)

In [13]:
epochs = 20

for epoch in range(epochs):
    # ---- training pass: one optimizer step per batch ----
    model.train()
    total_loss = 0.0
    num_train_batches = 0
    for x, y in real_train_loader:
        x, y = x.cuda(), y.cuda()

        optim.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optim.step()

        # .item() already yields a detached Python float.
        total_loss += loss.item()
        num_train_batches += 1

    # Mean of the per-batch losses; NaN instead of a ZeroDivisionError when
    # the loader yields nothing.
    epoch_loss = total_loss / num_train_batches if num_train_batches else float("nan")

    # ---- validation pass: no gradients, eval-mode layers ----
    model.eval()
    weighted_acc_sum = 0.0
    num_val_samples = 0
    with torch.no_grad():
        for x, y in real_val_loader:
            # Move the input to the GPU explicitly, as in the training pass.
            # Labels stay where they are: get_accuracy compares on the CPU.
            pred = model(x.cuda())

            # Weight each batch by its size so that a smaller final batch does
            # not skew the reported accuracy.
            batch_size = len(y)
            weighted_acc_sum += get_accuracy(pred, y) * batch_size
            num_val_samples += batch_size

    avg_val_acc = weighted_acc_sum / num_val_samples if num_val_samples else float("nan")

    print("Epoch: {}, Training Loss: {}, Avg. Validation Accuracy: {}".format(epoch, epoch_loss, avg_val_acc))

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 3; 23.70 GiB total capacity; 0 bytes already allocated; 20.56 MiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF