In [1]:
from functools import partial
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as v2
from torch.utils.data import random_split, DataLoader
from ray import tune, train
import tempfile
import torchDatasets as ds
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
# from torchvision.models import resnet50, resnet18
import networks as custNN
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available to this process, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Labels used in log messages.
modelName = "Autoencoder"
trainScheme = "Autoencoder"

# Side length passed to the dataset constructors (ConvAutoencoder below
# assumes 32x32 inputs).
side = 32

# Project root. The original machine-specific path stays the default, but it can
# be overridden with the GONIHEDRIC_BASE_DIR environment variable so the
# notebook is runnable on another machine without editing code.
base_dir = os.environ.get("GONIHEDRIC_BASE_DIR", '/home/shashank/Code/gonihedric/')

# Callers build file paths by plain string concatenation (data_dir + "name"),
# so dataDir must end with a path separator; the trailing "" guarantees that.
dataDir = os.path.join(base_dir, "data", "")

# Binary cross-entropy between the reconstruction and its input.
criterion = nn.BCELoss()

# class ConvAutoencoder(nn.Module):
#     def __init__(self, latent_dim=8, dropout=0.2, internal_activaton=nn.ReLU(), output_activation=nn.Sigmoid()):
#         super(ConvAutoencoder, self).__init__()

#         self.drpt = nn.Dropout(dropout)

#         # ----- Encoder -----
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # -> (32, H/2, W/2)
#             # nn.BatchNorm2d(32),
#             internal_activaton,

#             # nn.Dropout(dropout),
#             nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # -> (64, H/4, W/4)
#             # nn.BatchNorm2d(64),
#             internal_activaton,

#             # nn.Dropout(dropout),
#             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# -> (128, H/8, W/8)
#             # nn.BatchNorm2d(128),
#             internal_activaton,
#         )

#         # Bottleneck (latent space)
#         self.fc_enc = nn.Linear(128 * 4 * 4, latent_dim)   # assumes input = 32x32
#         self.fc_dec = nn.Linear(latent_dim, 128 * 4 * 4)

#         # ----- Decoder -----
#         self.decoder = nn.Sequential(
#             # nn.Dropout(dropout),
#             nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # -> (64, H/4, W/4)
#             # nn.BatchNorm2d(64),
#             internal_activaton,

#             # nn.Dropout(dropout),
#             nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # -> (32, H/2, W/2)
#             # nn.BatchNorm2d(32),
#             internal_activaton,

#             # nn.Dropout(dropout),
#             nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (1, H, W)
#             output_activation
#         )

#     def forward(self, x):
#         # Encode
#         x = self.encoder(x)
#         x = x.view(x.size(0), -1)      # flatten for FC
#         x = self.drpt(x)
#         z = self.fc_enc(x)

#         # Decode
#         z = self.drpt(z)
#         x = self.fc_dec(z)
#         x = x.view(x.size(0), 128, 4, 4)  # reshape back
#         x = self.decoder(x)
#         return x

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel square images.

    Three stride-2 convolutions halve the spatial size three times, a pair of
    linear layers squeezes the flattened features through a ``latent_dim``-wide
    bottleneck, and three stride-2 transposed convolutions double the size
    three times. The final sigmoid keeps every output value in [0, 1].

    Args:
        latent_dim: width of the bottleneck vector.
        dropout: drop probability shared by every dropout layer.
        side: height and width of the input images. Must be a positive
            multiple of 8 so that the decoder reproduces the input size
            exactly. Defaults to 32, which gives the original 128*4*4
            bottleneck input.

    Raises:
        ValueError: if ``side`` is not a positive multiple of 8.
    """

    # Total spatial reduction of the encoder (three stride-2 convolutions).
    _DOWNSAMPLE = 8
    # Channel count of the last encoder / first decoder feature map.
    _BOTTLENECK_CHANNELS = 128

    def __init__(self, latent_dim=8, dropout=0.2, side=32):
        super(ConvAutoencoder, self).__init__()

        if side <= 0 or side % self._DOWNSAMPLE != 0:
            raise ValueError(
                "side must be a positive multiple of %d, got %r"
                % (self._DOWNSAMPLE, side)
            )
        self.side = side
        # Spatial size of the feature map at the bottleneck (4 for side=32).
        self._feat_side = side // self._DOWNSAMPLE
        flat_features = self._BOTTLENECK_CHANNELS * self._feat_side * self._feat_side

        self.drpt = nn.Dropout(dropout)

        # ----- Encoder -----
        self.encoder = nn.Sequential(
            # Only this first convolution uses circular padding; the later ones zero-pad.
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, padding_mode='circular'),  # -> (32, H/2, W/2)
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.Dropout(dropout),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> (64, H/4, W/4)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> (128, H/8, W/8)
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )

        # Bottleneck (latent space)
        self.fc_enc = nn.Linear(flat_features, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, flat_features)

        # ----- Decoder -----
        self.decoder = nn.Sequential(
            nn.Dropout(dropout),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (64, H/4, W/4)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (32, H/2, W/2)
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (1, H, W)
            nn.Sigmoid()  # keeps output in [0,1] for binary images
        )

    def forward(self, x):
        """Encode ``x`` of shape (batch, 1, side, side) and return its reconstruction."""
        # Encode
        x = self.encoder(x)
        x = x.flatten(1)               # (batch, 128 * feat * feat) for the FC layer
        x = self.drpt(x)
        z = self.fc_enc(x)

        # Decode
        z = self.drpt(z)
        x = self.fc_dec(z)
        # Restore the (channels, feat, feat) layout expected by the decoder.
        x = x.view(x.size(0), self._BOTTLENECK_CHANNELS, self._feat_side, self._feat_side)
        x = self.decoder(x)
        return x


class ReshapeTransform:
    """Callable transform that reshapes a tensor to a fixed target shape.

    Args:
        shape: iterable of target dimensions, unpacked into ``Tensor.reshape``.
    """

    def __init__(self, shape):
        self.shape = shape

    def __call__(self, x):
        # reshape() returns a view when the memory layout allows it and copies
        # otherwise, so non-contiguous inputs no longer raise as they did
        # with view().
        return x.reshape(*self.shape)

    def __repr__(self):
        return "{}(shape={})".format(type(self).__name__, self.shape)

def load_data(data_dir, config):
    """Build the model and the train/test datasets for one trial.

    Args:
        data_dir: directory prefix; dataset names are appended by plain string
            concatenation, so it must end with a path separator.
        config: dict providing "latentSpace" (bottleneck width) and,
            optionally, "dropout" (defaults to 0.2).

    Returns:
        Tuple ``(model, trainset, testset)``. The model is already on
        ``device`` and is wrapped in ``nn.DataParallel`` when more than one
        CUDA device is visible.
    """
    # Earlier alternatives, kept for reference:
    # model, transform = custNN.modelPicker(modelName, side, nTargets, data_dir)
    # paper = [900, 750, 600, 450, 300, 150, 75, 30, 10, 2]
    # model = custNN.Autoencoder([900, 750, 600, 450, 300, 150, 75, 30, config["latentSpace"]], nn.Sigmoid(), nn.Sigmoid())
    # model = custNN.Autoencoder([1024, 750, 600, 450, 300, 150, 75, 30, 2], nn.ReLU(), nn.Sigmoid())
    model = ConvAutoencoder(config["latentSpace"], config.get("dropout", 0.2))

    # Move the parameters first: DataParallel requires the wrapped module to
    # live on the first GPU already. The original skipped this step on the
    # multi-GPU path.
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # transform = v2.Compose([v2.Lambda(lambda x: 2*x - 1)])
    # transform = v2.Compose([v2.RandomHorizontalFlip(),
    #                         v2.RandomVerticalFlip(),
    #                         ReshapeTransform((1,side,side))])
    # NOTE(review): the same random flip is also handed to the test set, so
    # test batches are randomly augmented too - confirm this is intended.
    transform = v2.RandomVerticalFlip()
    trainset = ds.CustomAutoencoderDataset(data_dir+"train2DGH32", side, transform) # 2DGH32 gnhd2dTest
    testset = ds.CustomAutoencoderDataset(data_dir+"test2DGH32", side, transform) # 2DGH32 gnhd2dTrain
    return model, trainset, testset

def initialize_weights(model):
    """Re-initialise a model's weights in place.

    Every ``nn.Linear`` gets weights drawn from a normal distribution with
    mean 0 and standard deviation 0.1, and a zero bias when it has one. If the
    model exposes a ``conv1`` attribute, its weight is drawn from the same
    distribution; models without one (such as ConvAutoencoder) are accepted
    instead of raising AttributeError.

    Note: trainFN in this notebook calls ``custNN.initialize_weights``, not
    this function.
    """
    first_conv = getattr(model, "conv1", None)
    if first_conv is not None:
        nn.init.normal_(first_conv.weight, 0, 0.1)

    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.1)
            # Linear(..., bias=False) has bias None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def trainFN(config):
    """Ray Tune trainable: train the autoencoder and report validation loss.

    After every epoch the mean validation loss is reported to Ray under the
    key "loss", together with a checkpoint holding
    ``(model_state_dict, optimizer_state_dict)`` in ``checkpoint.pt``.

    Args:
        config: dict with "lr" and "wd" for Adam, plus whatever ``load_data``
            reads ("latentSpace"). Optional keys: "epochs" (default 50) and
            "batch_size" (default 128).

    Raises:
        ValueError: if the 80/20 split leaves the validation subset empty.
    """
    num_epochs = config.get("epochs", 50)
    batch_size = config.get("batch_size", 128)

    net, trainset, _ = load_data(dataDir, config)
    custNN.initialize_weights(net)

    # optimizer = optim.SGD(net.parameters(),lr=config["lr"],
    #                       momentum=config["momentum"],weight_decay=config["wd"])
    optimizer = optim.Adam(net.parameters(), lr=config["lr"], weight_decay=config["wd"])
    # optimizer = optim.Adadelta(net.parameters(), lr=config["lr"], weight_decay=config["wd"])

    # Disabled; enable with e.g.
    # optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=config["gamma"])
    exp_lr_scheduler = None

    # Restore model/optimizer state when Ray hands us an existing checkpoint.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device,
            )
            net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    # 80/20 train/validation split of the training set.
    n_train = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [n_train, len(trainset) - n_train])
    if len(val_subset) == 0:
        raise ValueError(
            "training set of size %d is too small for an 80/20 train/validation split"
            % len(trainset)
        )

    trainloader = DataLoader(train_subset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4)
    # Order is irrelevant for evaluation, so the validation loader is not shuffled.
    valloader = DataLoader(val_subset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=4)

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        net.train()
        for i, (inputs, _) in enumerate(trainloader):
            # The autoencoder target is the input itself; labels are ignored.
            inputs = inputs.to(device)

            optimizer.zero_grad()  # zero the parameter gradients
            # forward + backward + optimize
            outputs = net(inputs).squeeze()
            loss = criterion(outputs, inputs.squeeze())
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                epoch_steps = 0
                running_loss = 0.0

        # Step the LR schedule once per epoch. Previously this lived inside the
        # print branch above and only ran every 2000 mini-batches.
        if exp_lr_scheduler is not None:
            exp_lr_scheduler.step()

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        net.eval()
        with torch.no_grad():
            for inputs, _ in valloader:
                inputs = inputs.to(device)
                outputs = net(inputs).squeeze().float()
                # .item() keeps the accumulator a plain Python float.
                val_loss += criterion(outputs, inputs.squeeze()).item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be accessed through in ``get_checkpoint()``
        # in future iterations.
        # Note to save a file like checkpoint, you still need to put it under a directory
        # to construct a checkpoint.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save(
                (net.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(
                {"loss": (val_loss / val_steps)},
                checkpoint=checkpoint,
            )
    print("Finished Training")

def test_best_model(best_result):
    """Evaluate the best trial's checkpoint on the test set and print the loss.

    Args:
        best_result: Ray Tune result whose ``config`` is passed to
            ``load_data`` and whose ``checkpoint`` contains ``checkpoint.pt``
            with a ``(model_state_dict, optimizer_state_dict)`` tuple, as
            written by trainFN.

    Raises:
        ValueError: if the test loader yields no batches.
    """
    best_trained_model, _, testset = load_data(dataDir, best_result.config)

    # as_directory() mirrors how trainFN reads checkpoints. map_location lets a
    # checkpoint written on GPU be evaluated on a CPU-only machine.
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        model_state, _ = torch.load(
            os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device
        )
    best_trained_model.load_state_dict(model_state)

    testloader = DataLoader(testset, batch_size=128,
                            shuffle=False, num_workers=4)

    total_loss = 0.0
    n_batches = 0
    best_trained_model.eval()
    with torch.no_grad():
        for inputs, _ in testloader:
            inputs = inputs.to(device)
            outputs = best_trained_model(inputs).squeeze().float()
            # The target is the input itself; accumulate as a Python float.
            total_loss += criterion(outputs, inputs.squeeze()).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("test set produced no batches; cannot compute a mean loss")

    # Mean of the per-batch losses.
    print("Best trial test set loss for \"{}\": {}".format(trainScheme, total_loss / n_batches))

def main(num_samples=10, max_num_epochs=10, cpus_per_trial=6, gpus_per_trial=2):
    """Run the Ray Tune hyperparameter search and evaluate the best trial.

    Args:
        num_samples: number of configurations sampled from the search space.
        max_num_epochs: ``max_t`` of the ASHA scheduler, i.e. the largest
            number of reported iterations a trial may reach.
        cpus_per_trial: CPUs reserved for each trial.
        gpus_per_trial: GPUs reserved for each trial.
    """
    config = {
        # loguniform takes (lower, upper). The bounds used to be passed in
        # reverse; the sampled range is unchanged.
        "lr": tune.loguniform(5e-5, 1e-1),
        "latentSpace": tune.choice([2, 4, 8, 16]),
        # "batch_size": tune.choice([128, 256]),
        "wd": tune.loguniform(1e-5, 1e-1),
        # "momentum": tune.uniform(0.1, 1.0),
        # "amsgrad": tune.choice([True, False]),
        # "dropout": tune.uniform(0.0, 0.5),
    }
    # trainFN optimises with Adam; the banner used to say SGD.
    print(modelName + " with Adam for OP side:" + str(side))

    # ASHA stops under-performing trials early, halving the survivors at each rung.
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2
    )
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(trainFN),
            resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
    )
    results = tuner.fit()
    # Raises RuntimeError when no trial reported "loss" (e.g. every trial failed).
    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    # if nTargets > 1:
    #     print("Best trial final validation accuracy: {}".format(
    #         best_result.metrics["accuracy"]))
    test_best_model(best_result)

In [2]:
# Launch the search: 20 sampled configurations, an ASHA cap (max_t) of 100
# reported iterations, one GPU per trial; cpus_per_trial keeps its default of 6.
main(num_samples=20, max_num_epochs=100, gpus_per_trial=1)

0,1
Current time:,2025-10-06 17:18:57
Running for:,00:00:10.49
Memory:,8.7/31.3 GiB

Trial name,# failures,error file
trainFN_c1040_00000,1,"/tmp/ray/session_2025-10-06_17-18-44_628543_71667/artifacts/2025-10-06_17-18-46/trainFN_2025-10-06_17-18-44/driver_artifacts/trainFN_c1040_00000_0_latentSpace=2,lr=0.0006,wd=0.0000_2025-10-06_17-18-46/error.txt"
trainFN_c1040_00001,1,"/tmp/ray/session_2025-10-06_17-18-44_628543_71667/artifacts/2025-10-06_17-18-46/trainFN_2025-10-06_17-18-44/driver_artifacts/trainFN_c1040_00001_1_latentSpace=4,lr=0.0146,wd=0.0325_2025-10-06_17-18-46/error.txt"
trainFN_c1040_00002,1,"/tmp/ray/session_2025-10-06_17-18-44_628543_71667/artifacts/2025-10-06_17-18-46/trainFN_2025-10-06_17-18-44/driver_artifacts/trainFN_c1040_00002_2_latentSpace=2,lr=0.0002,wd=0.0000_2025-10-06_17-18-46/error.txt"

Trial name,status,loc,latentSpace,lr,wd
trainFN_c1040_00003,PENDING,,2,0.00322037,1.37085e-05
trainFN_c1040_00004,PENDING,,2,0.000662406,0.000122391
trainFN_c1040_00005,PENDING,,16,0.00523295,0.00418077
trainFN_c1040_00006,PENDING,,16,0.0950636,0.000678475
trainFN_c1040_00007,PENDING,,4,0.000173013,0.00899187
trainFN_c1040_00008,PENDING,,8,0.00325957,0.000142657
trainFN_c1040_00009,PENDING,,4,0.000130971,0.00175306
trainFN_c1040_00010,PENDING,,16,0.000301984,0.0202366
trainFN_c1040_00011,PENDING,,2,0.00061134,8.52108e-05
trainFN_c1040_00012,PENDING,,2,0.0314486,0.000658477


2025-10-06 17:18:49,710	ERROR tune_controller.py:1331 -- Trial task failed for trial trainFN_c1040_00000
Traceback (most recent call last):
  File "/home/shashank/miniconda3/envs/neuralnets/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/home/shashank/miniconda3/envs/neuralnets/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/shashank/miniconda3/envs/neuralnets/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/shashank/miniconda3/envs/neuralnets/lib/python3.11/site-packages/ray/_private/worker.py", line 2753, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^

RuntimeError: No best trial found for the given metric: loss. This means that no trial has reported this metric, or all values reported for this metric are NaN. To not ignore NaN values, you can set the `filter_nan_and_inf` arg to False.

In [19]:
# Same search as the cell above (20 samples, max_t=100, one GPU per trial).
# trainFN itself loops for 50 epochs, so a trial reports at most 50 iterations.
main(num_samples=20, max_num_epochs=100, gpus_per_trial=1)

0,1
Current time:,2025-10-06 11:02:08
Running for:,00:02:09.46
Memory:,8.7/31.3 GiB

Trial name,status,loc,latentSpace,lr,wd,iter,total time (s),loss
trainFN_d6512_00000,TERMINATED,134.109.17.190:61138,8,0.00751988,0.00112144,50,22.9955,0.409437
trainFN_d6512_00001,TERMINATED,134.109.17.190:62711,2,0.00107592,0.000130657,2,1.32387,0.606898
trainFN_d6512_00002,TERMINATED,134.109.17.190:62835,4,0.0643911,0.0593757,2,1.27687,0.698435
trainFN_d6512_00003,TERMINATED,134.109.17.190:62960,2,0.000226522,0.00424682,1,0.875912,0.705201
trainFN_d6512_00004,TERMINATED,134.109.17.190:63054,16,0.0270434,0.00298424,2,1.2869,0.75934
trainFN_d6512_00005,TERMINATED,134.109.17.190:63179,8,0.00766128,2.23199e-05,1,0.884244,0.697429
trainFN_d6512_00006,TERMINATED,134.109.17.190:63273,8,9.20462e-05,0.0145525,2,1.27778,0.695723
trainFN_d6512_00007,TERMINATED,134.109.17.190:63397,8,0.000220383,7.74112e-05,4,2.23773,0.567568
trainFN_d6512_00008,TERMINATED,134.109.17.190:63582,8,7.70759e-05,0.000955027,1,0.88641,0.697493
trainFN_d6512_00009,TERMINATED,134.109.17.190:63676,2,0.0847155,0.000468349,1,0.821227,0.723539


2025-10-06 11:02:08,768	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/shashank/ray_results/trainFN_2025-10-06_10-59-59' in 0.0050s.
2025-10-06 11:02:08,772	INFO tune.py:1041 -- Total run time: 129.47 seconds (129.45 seconds for the tuning loop).


Best trial config: {'lr': 0.007519875551227198, 'latentSpace': 8, 'wd': 0.0011214449822134858}
Best trial final validation loss: 0.40943703055381775
Best trial test set loss for "Autoencoder": 0.38199755549430847


In [36]:
# Fresh ConvAutoencoder with an 8-wide bottleneck. No weights are loaded here
# (the load_state_dict calls below are commented out), so the reconstructions
# plotted by this cell come from randomly initialised parameters.
model = ConvAutoencoder(8).to(device)
# model = custNN.Autoencoder([900, 750, 600, 450, 300, 150, 75, 30, 10, 2], nn.Tanh(), nn.Tanh())
# model.load_state_dict(torch.load(data_dir+f"checkpoints/Autoencoder/"+date+"/2/model_epoch_60.pth", map_location=device))
# model.load_state_dict(torch.load(data_dir+f"checkpoints/Autoencoder/model.pth", map_location=device))

# Image side length used by the reshape transform and the plotting call below.
side = 32
class ReshapeTransform:
    """Callable transform that reshapes a tensor to a fixed target shape.

    NOTE(review): this duplicates the ReshapeTransform defined in the first
    cell and shadows it once this cell runs; keep one definition.

    Args:
        shape: iterable of target dimensions, unpacked into ``Tensor.reshape``.
    """

    def __init__(self, shape):
        self.shape = shape

    def __call__(self, x):
        # reshape() copies only when a view is impossible, so non-contiguous
        # inputs no longer raise as they did with view().
        return x.reshape(*self.shape)

def visualize_reconstruction(model, data_loader, side=30):
    """Plot originals (top row) above their reconstructions (bottom row).

    Takes the first batch from ``data_loader``, runs it through ``model`` in
    eval mode, and shows up to 8 image pairs.

    Args:
        model: module mapping a batch of images to reconstructions.
        data_loader: iterable yielding ``(images, labels)`` batches.
        side: each image is reshaped to ``(side, side)`` for display. The
            default of 30 does not match the 32-pixel data used in this
            notebook, so pass ``side`` explicitly.
    """
    model.eval()
    with torch.no_grad():
        images, _ = next(iter(data_loader))
        images = images.to(device)
        reconstructed = model(images)

        # Never index past the end of a batch smaller than 8.
        n_show = min(8, images.size(0))

        # squeeze=False keeps `axes` 2-D even when only one column is drawn.
        fig, axes = plt.subplots(2, n_show, figsize=(15, 4), squeeze=False)
        for i in range(n_show):
            # Original images
            axes[0, i].imshow(images[i].cpu().numpy().reshape(side, side), cmap='gray')
            axes[0, i].axis('off')

            # Reconstructed images
            axes[1, i].imshow(reconstructed[i].cpu().numpy().reshape(side, side), cmap='gray')
            axes[1, i].axis('off')

        plt.tight_layout()
        plt.show()
        


# Reshape every sample to (1, side, side) before it reaches the model.
transform = ReshapeTransform((1, side, side)) # v2.Compose([v2.Lambda(lambda x: 2*x - 1)]) # None #
# NOTE(review): despite the name `trainset`, this loads the "testPUD" dataset.
trainset = ds.CustomAutoencoderDataset(dataDir+"testPUD", side, transform)
data_loader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True)

# Show one shuffled batch next to its reconstructions.
visualize_reconstruction(model, data_loader, side)
# visualize_reconstruction(model,  smallTrainLoader)

    

RuntimeError: running_mean should contain 16 elements not 64

In [13]:
# Check of the raw data range: load "small2DGH32" without any transform and
# print the minimum value of the first image in one shuffled batch.
transform = None
dataset = ds.CustomAutoencoderDataset(dataDir+"small2DGH32", side, transform) # 2DGH32 gnhd2dTest
data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
# model = ConvAutoencoder(8).to(device)
# model.eval()
with torch.no_grad():
    images, _ = next(iter(data_loader))
    # images = images.to(device)
    # reconstructed = model(images)
    print(images[0].min())

tensor(0.)


In [4]:
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt

import torch
from torchvision.transforms import v2

# Save figures with a tight bounding box.
plt.rcParams["savefig.bbox"] = 'tight'

# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/

# NOTE(review): '../assets/astronaut.jpg' is resolved relative to the working
# directory; the recorded run of this cell failed with FileNotFoundError.
orig_img = Image.open(Path('../assets') / 'astronaut.jpg')

FileNotFoundError: [Errno 2] No such file or directory: '../assets/astronaut.jpg'

In [1]:
import torch
import torch.nn as nn
import networks as custNN

In [2]:
# Build the fully connected autoencoder from the layer-width list (900 down to
# a 2-wide bottleneck) with Tanh for both activation arguments, then print its
# layer structure.
model = custNN.Autoencoder([900, 750, 600, 450, 300, 150, 75, 30, 10, 2], nn.Tanh(), nn.Tanh())
print(model)

Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=900, out_features=750, bias=True)
    (1): Tanh()
    (2): Linear(in_features=750, out_features=600, bias=True)
    (3): Tanh()
    (4): Linear(in_features=600, out_features=450, bias=True)
    (5): Tanh()
    (6): Linear(in_features=450, out_features=300, bias=True)
    (7): Tanh()
    (8): Linear(in_features=300, out_features=150, bias=True)
    (9): Tanh()
    (10): Linear(in_features=150, out_features=75, bias=True)
    (11): Tanh()
    (12): Linear(in_features=75, out_features=30, bias=True)
    (13): Tanh()
    (14): Linear(in_features=30, out_features=10, bias=True)
    (15): Tanh()
    (16): Linear(in_features=10, out_features=2, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=10, bias=True)
    (1): Tanh()
    (2): Linear(in_features=10, out_features=30, bias=True)
    (3): Tanh()
    (4): Linear(in_features=30, out_features=75, bias=True)
    (5): Tanh()
    (6): Linear(in_fe

In [3]:
# Scratch check: decimal value of the scientific-notation literal 5e-4.
5e-4

0.0005

In [19]:
from utilsTrainTest import visualize_reconstruction
import torchvision.datasets as datasets
import torchvision.transforms as v2
# Reconstruct MNIST digits with a saved fully connected autoencoder.
# NOTE(review): this cell uses `device`, `DataLoader` and `ReshapeTransform`,
# none of which are imported or defined in this cell - it depends on earlier
# kernel state and will not run on a fresh kernel as-is.
side = 28
# First layer width 784 = 28 * 28, i.e. one flattened image per row.
model = custNN.Autoencoder([784, 600, 450, 300, 150, 75, 30, 10, 2], nn.Tanh(), nn.Tanh())
model.load_state_dict(torch.load("/home/shashank/Code/gonihedric/data/checkpoints/modelSecond.pth", map_location=device))
model = model.to(device)
model.eval()

# Convert to a tensor, apply the affine map x -> 2x - 1, then reshape each
# image to (1, side*side). Presumably the rescaling matches the Tanh output
# range - confirm against how the checkpoint was trained.
transform = v2.Compose([
    v2.ToTensor(),
    v2.Lambda(lambda x: 2*x - 1),
    ReshapeTransform((1, side*side))
    # v2.Lambda(lambda x: torch.flatten(x, start_dim=1)),  # Flatten the image
    # v2.Normalize((0.1307,), (0.3081,)),
    # v2.Normalize((0.5,), (0.5,)),
    # v2.Lambda(lambda x: x.view(-1) - 0.5)
])

# Only the training split is created with download=True.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
batchSize = 10
train_loader = DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batchSize, shuffle=False)
# This is the helper imported from utilsTrainTest; its signature
# (model, device, loader, side, location=...) differs from the
# visualize_reconstruction defined earlier in the notebook.
visualize_reconstruction(model, device, train_loader, side, location="/home/shashank/Code/gonihedric/data/checkpoints")
visualize_reconstruction(model, device, test_loader, side, location="/home/shashank/Code/gonihedric/data/checkpoints")

In [3]:
# Flattened size of a 28x28 image; matches the 784-wide first layer of the
# MNIST autoencoder in the cell above.
28*28

784