# Trialling Low Rank Optimisation

This is a small notebook experiment to test Low Rank Optimisation using the GaLore optimiser: we train the same DCGAN on CelebA with regular AdamW, GaLore AdamW, and 8-bit GaLore AdamW.

It's based on [PyTorch's DCGAN tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).

## Configuration

### Environment

We use [miniconda](https://docs.conda.io/en/latest/miniconda.html) (and we recommend also having [mamba](https://mamba.readthedocs.io/en/latest/installation.html#installation)) to set up the environment, and this is the only setup we have tested. If you want to install the packages through pip, you are on your own.

We provide a simple setup script that checks your system's GPU and CUDA versions using `nvidia-smi`, (re-)creates the environment using `conda`, and installs the packages for your system using either `mamba` or `conda`. To set up your environment with this script, you only need to run:
```sh
chmod a+x loraopt_setup.sh
./loraopt_setup.sh
```
And then [add the environment's kernel to Jupyter](https://arshren.medium.com/how-to-setup-conda-environments-and-add-kernels-for-jupyter-notebook-f2ebf968a409) in case your notebook engine doesn't find it automatically:
```sh
conda activate loraopt
python -m ipykernel install --user --name=loraopt --display-name "loraopt"
conda deactivate
jupyter notebook
```

Note that we have only tested our environment on an NVIDIA RTX A4000 running Ubuntu 22.04.02, with driver version 535.183.01, NVIDIA-SMI 535.183.01, and CUDA version 12.2. If you run into problems with other setups, feel free to raise an issue.
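The CUDA version above is the one printed in the `nvidia-smi` header, which is the highest CUDA version the installed driver supports, not necessarily an installed toolkit. We don't reproduce `loraopt_setup.sh` here; as a rough, hypothetical Python equivalent of its GPU check (the function names are ours, not the script's), that version can be parsed from the header like this:

```python
import re
import shutil
import subprocess

def parse_cuda_version(nvidia_smi_output):
    """Return (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi's header, or None."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    return (int(match.group(1)), int(match.group(2))) if match else None

def detect_cuda_version():
    """Run nvidia-smi if it is on PATH; None means no usable NVIDIA driver was found."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return parse_cuda_version(result.stdout)

# Example header line in the format nvidia-smi prints
header = "| NVIDIA-SMI 535.183.01    Driver Version: 535.183.01    CUDA Version: 12.2 |"
print(parse_cuda_version(header))  # (12, 2)
```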

All our `.yml` files pin the major and minor versions of the libraries, and the debug versions of some of them. If you can't find a combination that works on your system, try [relaxing the versions](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#create-env-file-manually). However, this notebook might not work with some of the relaxed versions.
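Relaxing can be done by hand in the `.yml` file. As a hypothetical helper (not part of the repository), the sketch below trims conda-style specs such as `numpy=1.26.4=<build>` down to `numpy=1.26`. In conda's spec syntax a single `=` is a fuzzy match, so `numpy=1.26` accepts any `1.26.x` release. Lines using pip's `==` syntax are deliberately left untouched.

```python
import re

# "  - name=version" or "  - name=version=build" (conda syntax); pip's "==" lines are skipped
_CONDA_SPEC = re.compile(r"^(\s*-\s*)([A-Za-z0-9_.\-]+)=([0-9][^=\s]*)(=\S+)?\s*$")

def relax_conda_spec(line, keep=2):
    """Keep only the first `keep` version components of a conda spec and drop its build string."""
    if "==" in line:
        return line
    match = _CONDA_SPEC.match(line)
    if match is None:
        return line
    prefix, name, version, _build = match.groups()
    if keep <= 0:
        return f"{prefix}{name}"  # no version constraint at all
    return f"{prefix}{name}=" + ".".join(version.split(".")[:keep])

def relax_environment_text(text, keep=2):
    """Apply relax_conda_spec to every line of an environment file's contents."""
    return "\n".join(relax_conda_spec(line, keep) for line in text.splitlines())

print(relax_conda_spec("  - numpy=1.26.4=py311h64a7726_0"))  # -> "  - numpy=1.26"
```

Lowering `keep` to 1 or 0 relaxes a spec further (major version only, or unpinned).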

### Folders and settings

The strings below define where we search for and store data, results, images, and models. Please change them if you want to use folders other than the default ones. Some of the libraries we use may have library-specific default folders that can be shared among environments; we keep those separate.

We also define some variables that control behaviour throughout the notebook, for example which image extensions to save, whether to show images in the notebook, etc.

In [None]:
# Folders to use
data_folder = "~/data"
results_folder = "./loraopt/results"
models_folder = "./loraopt/models"
images_folder = "./loraopt/images"

# GAN training hyperparameters
N_WORKERS = 2
BATCH_SIZE = 512
N_EPOCHS = 8
LEARNING_RATE = 2e-4
BETA1 = 0.5
IMAGE_SIZE = 64 #128
NUM_CHANNELS = 3
GAN_LATENT_DIM = 100
N_GEN_F = 64
N_DIS_F = 64

# GaLore params
GALORE_RANK = 4
GALORE_UPDATE_PROJ_GAP = 50
GALORE_SCALE = 0.25
GALORE_PROJ_TYPE = "std"

# Point-grid and KDE generative exploration hyperparameters
NUM_POINTS_GRID = 128
NUMBER_OF_SAMPLES_PER_CLASS = 4
TOP_PCT_TO_SAMPLE_FROM = 0.01
SOFTMAX_REPARAM_TEMPERATURE = 1

# Image variables
YLABEL_FONTSIZE = 6
COORDS_FONTSIZE = 8
SHOW_IMAGES = True
SAVE_IMAGES = True
# I do not recommend saving vector images, as they become quite large given the number of points being plotted.
IMAGE_FORMATS = ["jpg", "png"]
SHOW_ACQUISITIONS = False

N_GPU = 1

In [None]:
import argparse

In [None]:
import os
import os.path as osp

data_folder, results_folder, models_folder, images_folder = map(
    osp.expanduser,
    map(
        osp.expandvars,
        (data_folder, results_folder, models_folder, images_folder)
    )
)

for f in [data_folder, results_folder, models_folder]:
    os.makedirs(f, exist_ok=True)

for fmt in IMAGE_FORMATS:
    os.makedirs(osp.join(images_folder,fmt), exist_ok=True)
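The configuration defines `SAVE_IMAGES` and `IMAGE_FORMATS`, and the cell above creates one subfolder per format under `images_folder`. The code that writes figures into that layout isn't part of this section; a hypothetical helper (the names are ours) that saves a Matplotlib figure once per format could look like this:

```python
import os.path as osp

def figure_paths(images_folder, name, formats):
    """One output path per format, e.g. <images_folder>/png/<name>.png."""
    return [osp.join(images_folder, fmt, f"{name}.{fmt}") for fmt in formats]

def save_figure(fig, name, images_folder, formats, enabled=True):
    """Save `fig` (a Matplotlib Figure) in every format; returns the paths written."""
    if not enabled:
        return []
    paths = figure_paths(images_folder, name, formats)
    for path in paths:
        # Matplotlib infers the file format from the extension
        fig.savefig(path, bbox_inches="tight")
    return paths
```

For example, `save_figure(plt.gcf(), "losses_adamw", images_folder, IMAGE_FORMATS, enabled=SAVE_IMAGES)` placed just before `plt.show()` in a plotting cell.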

### Imports

In [None]:
from more_itertools import interleave, take
from itertools import chain

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
import random

import pandas as pd
import numpy as np
import scipy as sp
import scipy.stats as sps

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader
from torchvision.datasets import CelebA
import torchvision.transforms as tvtransforms
import torchvision.utils as tvutils

In [None]:
import tensorly as tl
tl.set_backend("pytorch")

from galore_torch import GaLoreAdamW
from galore_torch import GaLoreAdamW8bit

def make_galore_param_groups(model:nn.Module, rank=GALORE_RANK, update_proj_gap=GALORE_UPDATE_PROJ_GAP, scale=GALORE_SCALE, proj_type=GALORE_PROJ_TYPE):
    """
    Builds a GaLore param group list by capturing all Linear layers and assigning them the given rank,
    and capturing all Conv2d and ConvTranspose2d layers and assigning them the given rank on their first
    two (channel) dimensions while keeping full rank on the two kernel dimensions.
    All remaining parameters (e.g. BatchNorm weights and biases) go into a regular param group.
    """
    galore_params = {
        module_type_str: [
            module.weight
            for module_name, module in model.named_modules()
            if isinstance(module, ModuleType)
        ] for module_type_str, ModuleType in (("conv2d", (nn.Conv2d,nn.ConvTranspose2d)), ("linear", nn.Linear),)
    }
    id_galore_params = set(chain(*[[id(p) for p in galore_params[k]] for k in ("conv2d", "linear")]))
    regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
    param_groups = []
    param_groups.append({'params': regular_params})
    # The DCGAN models have no Linear layers, so skip an empty group
    if galore_params["linear"]:
        param_groups.append({'params': galore_params["linear"], 'rank': rank, 'update_proj_gap': update_proj_gap, 'scale': scale, 'proj_type': proj_type})

    # Group conv weights by kernel shape (kH, kW), since the per-dimension rank depends on it
    conv_2d_weights_and_shapes = [(w, w.shape[-2:]) for w in galore_params["conv2d"]]
    conv_2d_unique_shapes = set([s for (w,s) in conv_2d_weights_and_shapes])
    for conv2d_shape in conv_2d_unique_shapes:
        # Iterate over the (weight, shape) pairs, not the bare weights, and compare the shapes directly
        param_groups.append({'params': [w for (w,s) in conv_2d_weights_and_shapes if s == conv2d_shape], 'rank': [rank,rank,conv2d_shape[0],conv2d_shape[1]], 'update_proj_gap': update_proj_gap, 'scale': scale, 'dim':4})

    return param_groups
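GaLore does not reduce the number of trainable parameters. As we understand it, for each weight in a GaLore param group it projects the gradient onto a low-rank subspace obtained from an SVD of the gradient (refreshed every `update_proj_gap` steps), keeps the Adam moments in that small subspace, and projects the resulting update back to the full shape, multiplied by `scale`. The NumPy sketch below shows one such step for the 2-D case, projecting from the shorter side of the matrix, which we believe matches the `"std"` projection type. It is a simplification (no weight decay, no projector refresh schedule), not `galore_torch`'s implementation; the 4-D conv groups built above use a Tucker-based projector instead.

```python
import numpy as np

def galore_project(grad, rank):
    """Project a 2-D gradient onto its top-`rank` singular subspace, on the shorter side."""
    m, n = grad.shape
    U, _, Vt = np.linalg.svd(grad, full_matrices=False)
    if m >= n:
        P = Vt[:rank, :].T            # (n, r): right projector
        return P, grad @ P            # low-rank gradient: (m, r)
    P = U[:, :rank]                   # (m, r): left projector
    return P, P.T @ grad              # low-rank gradient: (r, n)

def galore_project_back(P, low_rank_update, grad_shape, scale):
    """Map a low-rank update back to the full weight shape and apply the GaLore scale."""
    m, n = grad_shape
    if m >= n:
        return scale * (low_rank_update @ P.T)   # (m, r) @ (r, n)
    return scale * (P @ low_rank_update)         # (m, r) @ (r, n)

def adam_direction(g, m, v, step, beta1=0.5, beta2=0.999, eps=1e-8):
    """Bias-corrected Adam direction; m and v have the low-rank gradient's shape."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    return m_hat / (np.sqrt(v_hat) + eps), m, v

rng = np.random.default_rng(0)
weight = rng.standard_normal((64, 128))
grad = rng.standard_normal((64, 128))

P, low_rank_grad = galore_project(grad, rank=4)   # P would be recomputed every update_proj_gap steps
m_state = np.zeros_like(low_rank_grad)            # Adam moments live in the (4, 128) space
v_state = np.zeros_like(low_rank_grad)
direction, m_state, v_state = adam_direction(low_rank_grad, m_state, v_state, step=1)
weight -= 2e-4 * galore_project_back(P, direction, grad.shape, scale=0.25)
print(P.shape, low_rank_grad.shape, m_state.shape)  # (64, 4) (4, 128) (4, 128)
```

The memory saving comes from `m_state` and `v_state` being 4 × 128 instead of 64 × 128.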

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import seaborn as sns

In [None]:
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# Cannot use deterministic algorithms as either GaLore or tensorly require it to be disabled and I didn't want to set it as an environment variable on the conda env
#torch.use_deterministic_algorithms(True) # Needed for reproducible results
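If the environment variable mentioned in the comment is the cuBLAS workspace setting that PyTorch asks for when deterministic algorithms are enabled on CUDA (an assumption on our part, since the comment doesn't name it), it can be set from inside the notebook instead of in the conda env, as long as that happens before the first CUDA matrix multiplication. A sketch, which we have not tested together with GaLore or tensorly:

```python
import os

# cuBLAS reads this when it initialises, so it must be set before the first CUDA
# matrix multiplication in the process (restart the kernel if one already ran).
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

# warn_only=True reports ops without a deterministic implementation as warnings
# instead of raising errors, so training can still run if GaLore or tensorly use one.
torch.use_deterministic_algorithms(True, warn_only=True)
print(torch.are_deterministic_algorithms_enabled())  # True
```

With `warn_only=True`, any warnings printed during training indicate which ops remain non-deterministic.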

### Downloading the dataset

In [None]:
celeba_train = CelebA(
    root=data_folder,
    download=True,
    split="train",
    target_type="attr",
    transform=tvtransforms.Compose(
        [
            tvtransforms.Resize(IMAGE_SIZE),
            tvtransforms.CenterCrop(IMAGE_SIZE),
            tvtransforms.ToTensor(),
            tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
    )
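`ToTensor()` produces pixel values in [0, 1], and `Normalize` with mean and std of 0.5 per channel maps them onto [-1, 1], the same range as the generator's final `Tanh`. The arithmetic per channel, sketched in NumPy:

```python
import numpy as np

def normalize(x, mean=0.5, std=0.5):
    # Per-channel arithmetic of tvtransforms.Normalize: (x - mean) / std
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping, e.g. to turn generator (Tanh) outputs back into [0, 1] pixels
    return y * std + mean

pixels = np.array([0.0, 0.25, 0.5, 1.0])  # ToTensor() output range is [0, 1]
print(normalize(pixels))                   # values -1, -0.5, 0 and 1
```

For display, the notebook instead passes `normalize=True` to `tvutils.make_grid`, which rescales each grid by its own minimum and maximum rather than applying this exact inverse.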

In [None]:
dataloader = torch.utils.data.DataLoader(celeba_train, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=N_WORKERS)

In [None]:
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and N_GPU) else "cpu")

In [None]:
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(tvutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()

In [None]:
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
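`weights_init` matches modules by class name, so `'Conv'` catches both `Conv2d` and `ConvTranspose2d` (and any other class whose name contains `Conv`), and `'BatchNorm'` catches `BatchNorm2d`. An equivalent type-based variant, as a sketch (the name `weights_init_by_type` is ours), makes that explicit:

```python
import torch
import torch.nn as nn

def weights_init_by_type(m):
    # Same DCGAN initialisation as weights_init, but matched on module type
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

block = nn.Sequential(nn.ConvTranspose2d(4, 4, 4, bias=False), nn.BatchNorm2d(4))
block.apply(weights_init_by_type)
print(round(block[0].weight.std().item(), 2))  # roughly 0.02
```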

In [None]:
# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( GAN_LATENT_DIM, N_GEN_F * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(N_GEN_F * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(N_GEN_F * 8, N_GEN_F * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_GEN_F * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( N_GEN_F * 4, N_GEN_F * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_GEN_F * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( N_GEN_F * 2, N_GEN_F, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_GEN_F),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( N_GEN_F, NUM_CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
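Each `ConvTranspose2d` in the generator grows the spatial size as `H_out = (H_in - 1) * stride - 2 * padding + kernel_size` (with the default `dilation=1` and `output_padding=0`). Tracing the five layers from the 1×1 latent input reproduces the state sizes in the comments and shows why the architecture is tied to `IMAGE_SIZE = 64`: a 128×128 output would need an extra layer, so changing `IMAGE_SIZE` alone is not enough.

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of the five ConvTranspose2d layers in Generator
GENERATOR_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

def generator_sizes(start=1, layers=GENERATOR_LAYERS):
    sizes = [start]
    for kernel, stride, padding in layers:
        sizes.append(convtranspose2d_out(sizes[-1], kernel, stride, padding))
    return sizes

print(generator_sizes())  # [1, 4, 8, 16, 32, 64]
```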

In [None]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(NUM_CHANNELS, N_DIS_F, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(N_DIS_F, N_DIS_F * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_DIS_F * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(N_DIS_F * 2, N_DIS_F * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_DIS_F * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(N_DIS_F * 4, N_DIS_F * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(N_DIS_F * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(N_DIS_F * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
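Conversely, each `Conv2d` in the discriminator shrinks the spatial size as `H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1` (default `dilation=1`). For 64×64 inputs the last layer yields a single value per image, which is what `.view(-1)` in the training loop relies on to match `label`; with 128×128 inputs it would produce a 5×5 map per image instead:

```python
def conv2d_out(size, kernel, stride, padding):
    # Output size of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the five Conv2d layers in Discriminator
DISCRIMINATOR_LAYERS = [(4, 2, 1)] * 4 + [(4, 1, 0)]

def discriminator_sizes(image_size, layers=DISCRIMINATOR_LAYERS):
    sizes = [image_size]
    for kernel, stride, padding in layers:
        sizes.append(conv2d_out(sizes[-1], kernel, stride, padding))
    return sizes

print(discriminator_sizes(64))   # [64, 32, 16, 8, 4, 1]
print(discriminator_sizes(128))  # [128, 64, 32, 16, 8, 5]
```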

### Regular AdamW

In [None]:
# Create the generator
netG = Generator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netG = nn.DataParallel(netG, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)

In [None]:
# Create the Discriminator
netD = Discriminator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netD = nn.DataParallel(netD, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)


In [None]:
# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()
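For each element, `nn.BCELoss` computes `-(y * log(p) + (1 - y) * log(1 - p))` and, by default, averages over the batch. With the label convention used below (real = 1, fake = 0), D's two terms become `-log(D(x))` and `-log(1 - D(G(z)))`, while G's step labels its fakes as real and so minimises `-log(D(G(z)))`, which is the "maximize log(D(G(z)))" in the loop comments. A pure-Python version for checking numbers by hand:

```python
import math

def bce(p, y):
    # Binary cross-entropy for one prediction p = D(.) in (0, 1) and label y in {0, 1}
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_mean(probs, labels):
    # nn.BCELoss() defaults to reduction="mean": average over the batch
    return sum(bce(p, y) for p, y in zip(probs, labels)) / len(probs)

# D scores a real image at 0.9 (label 1): the loss is -log(0.9)
print(round(bce(0.9, 1.0), 4))  # 0.1054
# G's step: a fake scored at 0.1 but labelled real costs -log(0.1)
print(round(bce(0.1, 1.0), 4))  # 2.3026
```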

In [None]:
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, GAN_LATENT_DIM, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Set up AdamW optimizers for both G and D
optimizerD = optim.AdamW(netD.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
optimizerG = optim.AdamW(netG.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(N_EPOCHS):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, GAN_LATENT_DIM, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, N_EPOCHS, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == N_EPOCHS-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(tvutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
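The PyTorch DCGAN tutorial uses `optim.Adam`, whereas this notebook uses `optim.AdamW`, whose default `weight_decay` is 0.01 (Adam's is 0). Since the cell above doesn't pass `weight_decay`, this baseline trains with decoupled weight decay. We haven't checked what `GaLoreAdamW` defaults to, so to compare the optimisers on equal terms it may be worth passing the same value explicitly to all of them (assuming the GaLore optimisers accept the same keyword). The baseline's setting can be inspected directly:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)  # stand-in for netD / netG

# As in the cell above: weight_decay is not passed, so AdamW's default applies
default_opt = optim.AdamW(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
print(default_opt.param_groups[0]["weight_decay"])  # 0.01

# Passing it explicitly makes the setting visible and easy to match across optimisers
explicit_opt = optim.AdamW(model.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=0.0)
print(explicit_opt.param_groups[0]["weight_decay"])  # 0.0
```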

In [None]:
# nvtop: 96-97% 1914 MiB GPU 7m23.6s
#
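The nvtop notes record GPU load, memory, and run time by hand. To compare the three optimisers more repeatably, time and peak memory can also be read from PyTorch itself; below is a sketch with a hypothetical `measure` helper. Note that `torch.cuda.max_memory_allocated` only counts memory occupied by tensors, while nvtop also sees PyTorch's cached blocks and the CUDA context, so the numbers won't match; `torch.cuda.max_memory_reserved` is closer to nvtop's figure.

```python
import time
import torch

def measure(fn, device):
    """Run fn() once; return (result, wall-clock seconds, peak allocated bytes or None)."""
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)
    start = time.perf_counter()
    result = fn()
    if use_cuda:
        torch.cuda.synchronize(device)  # wait for queued kernels before stopping the clock
    elapsed = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated(device) if use_cuda else None
    return result, elapsed, peak

# Trivial workload for illustration; in the notebook, fn would wrap one training loop
result, seconds, peak_bytes = measure(lambda: sum(range(10)), torch.device("cpu"))
print(result, peak_bytes)  # 45 None
```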

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
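Per-iteration GAN losses are noisy, which makes the curves of the three optimisers hard to compare by eye. A hypothetical smoothing helper (not part of the original notebook) using a simple moving average:

```python
import numpy as np

def moving_average(values, window=50):
    """Simple moving average; returns the input unchanged if it is shorter than `window`."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values
    kernel = np.ones(window) / window
    # mode="valid" only keeps positions where the window fits entirely
    return np.convolve(values, kernel, mode="valid")
```

For example, `plt.plot(moving_average(G_losses), label="G (smoothed)")`. With `mode="valid"` the smoothed curve is `window - 1` points shorter, so it starts at iteration `window - 1` of the raw curve.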

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

### GaLore

In [None]:
# Create the generator
netG = Generator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netG = nn.DataParallel(netG, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)

In [None]:
# Create the Discriminator
netD = Discriminator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netD = nn.DataParallel(netD, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)


In [None]:
# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

In [None]:
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, GAN_LATENT_DIM, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Set up GaLore AdamW optimizers for both G and D
optimizerD = GaLoreAdamW(make_galore_param_groups(netD), lr=LEARNING_RATE, betas=(BETA1, 0.999))
optimizerG = GaLoreAdamW(make_galore_param_groups(netG), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(N_EPOCHS):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, GAN_LATENT_DIM, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, N_EPOCHS, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == N_EPOCHS-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(tvutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
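Apart from the optimisers, this loop is identical to the regular AdamW loop above and to the 8-bit GaLore loop below. A sketch of factoring one iteration into a function (the name `gan_train_step` is ours, not the tutorial's); it performs the same operations in the same order as the loop body:

```python
import torch
import torch.nn as nn

def gan_train_step(netG, netD, optimizerG, optimizerD, criterion, real, latent_dim,
                   real_label=1.0, fake_label=0.0):
    """One DCGAN iteration (update D, then G). Returns (errD, errG, D_x, D_G_z1, D_G_z2)."""
    device = real.device
    b_size = real.size(0)

    # (1) Update D: maximise log(D(x)) + log(1 - D(G(z)))
    netD.zero_grad()
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
    output = netD(real).view(-1)
    errD_real = criterion(output, label)
    errD_real.backward()
    D_x = output.mean().item()

    noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
    fake = netG(noise)
    label.fill_(fake_label)
    output = netD(fake.detach()).view(-1)  # detach: no gradients flow into G here
    errD_fake = criterion(output, label)
    errD_fake.backward()                   # accumulates onto the real-batch gradients
    D_G_z1 = output.mean().item()
    optimizerD.step()

    # (2) Update G: maximise log(D(G(z))) by labelling the fakes as real
    netG.zero_grad()
    label.fill_(real_label)
    output = netD(fake).view(-1)
    errG = criterion(output, label)
    errG.backward()
    D_G_z2 = output.mean().item()
    optimizerG.step()

    return (errD_real + errD_fake).item(), errG.item(), D_x, D_G_z1, D_G_z2
```

With it, each loop body reduces to `errD, errG, D_x, D_G_z1, D_G_z2 = gan_train_step(netG, netD, optimizerG, optimizerD, criterion, data[0].to(device), GAN_LATENT_DIM)` followed by the existing logging and bookkeeping lines, and only the optimiser construction differs between sections.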

In [None]:
# nvtop: % 2064 MiB GPU 7m0.2s

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

### GaLore 8-bit

In [None]:
# Create the generator
netG = Generator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netG = nn.DataParallel(netG, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)

In [None]:
# Create the Discriminator
netD = Discriminator(N_GPU).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (N_GPU > 1):
    netD = nn.DataParallel(netD, list(range(N_GPU)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)


In [None]:
# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

In [None]:
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, GAN_LATENT_DIM, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Set up 8-bit GaLore AdamW optimizers for both G and D
optimizerD = GaLoreAdamW8bit(make_galore_param_groups(netD), lr=LEARNING_RATE, betas=(BETA1, 0.999))
optimizerG = GaLoreAdamW8bit(make_galore_param_groups(netG), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(N_EPOCHS):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, GAN_LATENT_DIM, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, N_EPOCHS, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == N_EPOCHS-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(tvutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
import tensorly as tl

In [None]:
weights = torch.randn([256, 128, 4, 4], dtype=torch.float16, device="cuda")
weights.shape

In [None]:
matrix = weights.data.float()
matrix.shape

In [None]:
tucker_tensor = tl.decomposition.tucker(matrix, rank=4)
tucker_tensor

In [None]:
tucker_tensor.core.shape

In [None]:
[f.shape for f in tucker_tensor.factors]
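The cells above run tensorly's Tucker decomposition on a conv-shaped tensor of shape (256, 128, 4, 4) with `rank=4` applied to every mode, which matches the `[rank, rank, kH, kW]` rank that `make_galore_param_groups` assigns to 4×4 kernels when `GALORE_RANK = 4`. To make the resulting shapes concrete without tensorly, here is a NumPy sketch of truncated HOSVD. This is a different, non-iterative algorithm from what, as far as we know, `tl.decomposition.tucker` runs by default (SVD-initialised alternating updates), so the values will differ, but the core and factor shapes are the same.

```python
import numpy as np

def unfold(tensor, mode):
    # Mode-n unfolding: bring `mode` to the front and flatten the remaining axes
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)

def mode_dot(tensor, matrix, mode):
    # Multiply `tensor` along `mode` by `matrix` of shape (new_dim, tensor.shape[mode])
    result = np.tensordot(matrix, np.moveaxis(tensor, mode, 0), axes=(1, 0))
    return np.moveaxis(result, 0, mode)

def truncated_hosvd(tensor, ranks):
    # One factor per mode: the leading left singular vectors of that mode's unfolding
    factors = [np.linalg.svd(unfold(tensor, mode), full_matrices=False)[0][:, :r]
               for mode, r in enumerate(ranks)]
    core = tensor
    for mode, U in enumerate(factors):
        core = mode_dot(core, U.T, mode)  # project every mode onto its factor
    return core, factors

def tucker_to_tensor(core, factors):
    out = core
    for mode, U in enumerate(factors):
        out = mode_dot(out, U, mode)
    return out

weights = np.random.default_rng(0).standard_normal((256, 128, 4, 4)).astype(np.float32)
core, factors = truncated_hosvd(weights, ranks=(4, 4, 4, 4))
print(core.shape, [f.shape for f in factors])
# (4, 4, 4, 4) [(256, 4), (128, 4), (4, 4), (4, 4)]
```

Because the two kernel modes keep their full size (4), only the channel modes are actually compressed.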


In [None]:
weights = torch.randn([256, 128], dtype=torch.float16, device="cuda")
weights.shape

In [None]:
matrix = weights.data.float()
matrix.shape

In [None]:
tucker_tensor = tl.decomposition.tucker(matrix, rank=[4,1,1,1])
tucker_tensor, tucker_tensor.core.shape, [f.shape for f in tucker_tensor.factors]


In [None]:
weights = torch.randn([256, 128, 4, 4], dtype=torch.float16, device="cuda")
weights = weights.reshape([256, -1])
weights.shape

In [None]:
matrix = weights.data.float()
matrix.shape

In [None]:
tucker_tensor = tl.decomposition.tucker(matrix, rank=[4, 1])
tucker_tensor, tucker_tensor.core.shape, [f.shape for f in tucker_tensor.factors]
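This last experiment treats the same weight as a 256 × 2048 matrix instead of a 4-D tensor. A back-of-the-envelope count of optimiser-side storage for this one layer compares the two views. It assumes (as we understand GaLore) that AdamW's two moment buffers take the shape of the projected gradient, that the matrix case projects from the shorter side, and that the Tucker case projects onto a `[4, 4, 4, 4]` core; projector storage is counted and 8-bit quantisation is ignored.

```python
from math import prod

def adam_state_entries(shape):
    # AdamW keeps two moment buffers with the shape of the tensor it optimises
    return 2 * prod(shape)

out_c, in_c, k_h, k_w = 256, 128, 4, 4
rank = 4

# Full-rank AdamW on the conv weight
full = adam_state_entries((out_c, in_c, k_h, k_w))

# Matrix view (256 x 2048): left projector P (256 x 4), moments on P^T G (4 x 2048)
matrix_galore = adam_state_entries((rank, in_c * k_h * k_w)) + out_c * rank

# Tucker view with rank [4, 4, 4, 4]: one factor per mode, moments on the (4, 4, 4, 4) core
tucker_factors = out_c * rank + in_c * rank + k_h * k_h + k_w * k_w
tucker_galore = adam_state_entries((rank, rank, k_h, k_w)) + tucker_factors

print(full, matrix_galore, tucker_galore)  # 1048576 17408 2080
```

Both low-rank views shrink the optimiser state by well over an order of magnitude for this layer; the end-to-end GPU memory recorded above also includes weights, activations, and the CUDA context, so it moves much less.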

In [None]:
# nvtop: 94-96% 1916 MiB GPU m.s

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())