In [1]:
from torch import nn, optim, cuda
import torch.nn.functional as F
from torchvision.transforms import transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

In [5]:
def load_dataloader(name, transform, bs):
  """Build the training and validation DataLoaders for a named dataset.

  Args:
    name: Dataset identifier. Only 'cifar10' is currently supported.
    transform: Transform handed to both the training and validation datasets.
    bs: Batch size used for both loaders.

  Returns:
    A `(train_dl, val_dl)` tuple. The training loader shuffles its data,
    the validation loader does not.

  Raises:
    ValueError: If `name` is not a supported dataset identifier.
  """
  # Fail loudly on an unknown name.  Falling off the end of the function would
  # return None and the caller's tuple unpacking would then die with an
  # unrelated "cannot unpack NoneType" TypeError.
  if name != 'cifar10':
    raise ValueError(f"Unsupported dataset {name!r}; expected 'cifar10'.")

  # The dataset is downloaded into ./data if it is not already there.
  train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
  val_ds = datasets.CIFAR10('./data', train=False, download=True, transform=transform)

  # num_workers is the number of worker subprocesses used for loading data.
  # PyTorch Lightning warns when too few workers are configured; a small
  # value is fine for small problems like this one.
  train_dl = DataLoader(train_ds, bs, shuffle=True, num_workers=2)
  val_dl = DataLoader(val_ds, bs, shuffle=False, num_workers=2)

  return train_dl, val_dl

In [6]:
class Model(LightningModule):
  """LightningModule wrapping a pretrained backbone with a 10-way classifier head.

  Args:
    model: Name of the backbone to build. Only 'resnet18' is supported.
    freeze: When truthy, every pretrained parameter is frozen and only the
      newly created final linear layer is trained.
  """

  def __init__(self, model='resnet18', freeze=False):
    super().__init__()
    self.model = self._build_model(model, freeze)

  def _build_model(self, name, freeze):
    """Create the backbone named `name` and swap in a fresh 10-output head.

    Raises:
      ValueError: If `name` is not a supported backbone.
    """
    # Reject unknown names instead of silently returning None, which would
    # only surface later as "'NoneType' object is not callable".
    if name != 'resnet18':
      raise ValueError(f"Unsupported model {name!r}; expected 'resnet18'.")

    model = models.resnet18(pretrained=True)

    if freeze:
      for params in model.parameters():
        params.requires_grad = False

    # Replace the head *after* freezing so the new layer keeps
    # requires_grad=True and is the only trainable part of a frozen network.
    model.fc = nn.Linear(512, 10)
    return model

  def forward(self, x):
    """Run the wrapped network on a batch `x` and return its raw outputs."""
    return self.model(x)

  def _compute_loss(self, batch):
    """Cross-entropy loss of the network's outputs for one `(x, y)` batch."""
    x, y = batch
    out = self(x)
    return F.cross_entropy(out, y)

  def training_step(self, batch, batch_idx):
    """Compute, log ("train_loss") and return the loss for a training batch."""
    loss = self._compute_loss(batch)
    self.log("train_loss", loss, prog_bar=True)

    return loss

  def validation_step(self, batch, batch_idx):
    """Compute, log ("val_loss") and return the loss for a validation batch."""
    loss = self._compute_loss(batch)
    self.log("val_loss", loss, prog_bar=True)

    return loss

  def configure_optimizers(self):
    """Optimize the wrapped network's parameters with Adam (default settings)."""
    return optim.Adam(self.model.parameters())

In [7]:
# Build the CIFAR-10 loaders (batch size 64, images only converted to tensors)
# and a ResNet-18 whose pretrained weights are frozen.
trainloader, valloader = load_dataloader(
    name='cifar10',
    transform=transforms.ToTensor(),
    bs=64,
)
model = Model(model='resnet18', freeze=True)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Decide how many GPUs to request from the Trainer: one when CUDA is usable,
# none otherwise.  When CUDA is present, also print some device details.
if not cuda.is_available():
    print('CUDA *not* available')
    number_of_gpus = 0
else:
    print('CUDA available')
    print('current device', cuda.current_device())
    print('device count', cuda.device_count())
    print('device name', cuda.get_device_name(0))
    print('arch list', cuda.get_arch_list())
    number_of_gpus = 1

CUDA *not* available


  return torch._C._cuda_getDeviceCount() > 0


In [9]:
# fast_dev_run=True runs just one batch to make sure things are running.
# REMOVE FOR NORMAL RUNS!
trainer = Trainer(
    max_epochs=2,
    fast_dev_run=True,
    gpus=number_of_gpus,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).


In [10]:
# Fit `model` using `trainloader` for training and `valloader` for validation.
trainer.fit(model, trainloader, valloader)


  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
5.1 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]