In [1]:
{
    "role": "user",
    "content": "I'm trying to decide whether to take another bootcamp."
}

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class ModelConfig:
    """Static registry of per-model chain-of-thought and sampling settings.

    Each config is a plain dict. Configs carrying ``begin_think`` /
    ``end_think`` describe explicit delimiter strings; configs carrying
    ``fuzzy_end_think_list`` describe textual markers instead. Every config
    has a ``generate_kwargs`` dict of sampling parameters.
    """

    # Config with explicit <think> ... </think> delimiters.
    MODEL_CONFIG_THINK_TOKENS = dict(
        begin_think="<think>",
        end_think="</think>",
        generate_kwargs=dict(temperature=0.6, top_k=20, min_p=0.0, top_p=0.95),
    )

    # Config for gpt-oss-20b, whose delimiters are channel-tagged strings.
    MODEL_CONFIG_GPT_OSS_20B = dict(
        begin_think="<|end|><|start|>assistant<|channel|>final<|message|>analysis<|message|>",
        end_think="<|end|><|start|>assistant<|channel|>final<|message|>",
        generate_kwargs=dict(temperature=0.6, top_k=20, min_p=0.0, top_p=0.95),
    )

    # Config with a textual "Answer:" marker and a repetition penalty.
    MODEL_CONFIG_GEMMA = dict(
        fuzzy_end_think_list=["Answer:"],
        generate_kwargs=dict(
            repetition_penalty=1.2,
            temperature=0.7,
            top_k=20,
            min_p=0.0,
            top_p=0.95,
        ),
    )

    # Config with a textual "Answer:" marker.
    MODEL_CONFIG_LLAMA = dict(
        fuzzy_end_think_list=["Answer:"],
        generate_kwargs=dict(temperature=0.6, top_k=20, min_p=0.0, top_p=0.95),
    )

    # Hugging Face model id -> config dict (several ids share one dict object).
    SUPPORTED_MODELS = {
        "Qwen/Qwen2.5-0.5B": MODEL_CONFIG_THINK_TOKENS,
        "Qwen/Qwen3-0.6B": MODEL_CONFIG_THINK_TOKENS,
        "Qwen/Qwen3-1.7B": MODEL_CONFIG_THINK_TOKENS,
        "Qwen/Qwen3-4B": MODEL_CONFIG_THINK_TOKENS,
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": MODEL_CONFIG_THINK_TOKENS,
        "Wladastic/Mini-Think-Base-1B": MODEL_CONFIG_GEMMA,
        "google/gemma-2-2b-it": MODEL_CONFIG_GEMMA,
        "openai/gpt-oss-20b": MODEL_CONFIG_GPT_OSS_20B,
        "meta-llama/Meta-Llama-3-8B-Instruct": MODEL_CONFIG_LLAMA,
        "meta-llama/Llama-2-7b-chat-hf": MODEL_CONFIG_LLAMA,
    }

    @staticmethod
    def get(model_name: str):
        """Return the config dict registered for ``model_name``.

        For an unregistered name an error line is printed and ``exit(1)``
        is called.
        """
        registry = ModelConfig.SUPPORTED_MODELS
        known = model_name in registry
        if not known:
            print(f"ERROR: model {model_name} not supported")
            exit(1)
        return registry[model_name]


class ModelPromptBuilder:
    """Collects chat messages for one model and renders them via a chat template.

    Messages are stored in ``self.history`` as ``{"role", "content"}`` dicts.
    Once a partial (to-be-continued) message has been appended, no further
    messages may be added.
    """

    def __init__(self, model_name: str, invokes_cot: bool = True):
        self.model_name = model_name
        self.invokes_cot = invokes_cot
        self.question = None
        self.history = []

        # Default rendering mode: open a fresh assistant section rather than
        # continuing the last message.
        self.add_generation_prompt = True
        self.continue_final_message = False

    def get_model_custom_instruction(self):
        """Return the extra instruction some models need, or None."""
        wants_answer_marker = (
            "google/gemma-2-2b-it",
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "meta-llama/Llama-2-7b-chat-hf",
        )
        if self.model_name in wants_answer_marker:
            return "Please write the string \"Answer: \" before the final answer."
        return None

    def add_system_instruction(self, system_instruction: str):
        """Append a system-role message."""
        self.add_to_history("system", system_instruction)

    def add_user_message(self, question: str, custom_instruction: str = None):
        """Record ``question`` and append it as a user message.

        ``custom_instruction`` defaults to a step-by-step prompt; the
        model-specific instruction, if any, is appended after a space.
        """
        self.question = question
        instruction = custom_instruction
        if instruction is None:
            instruction = "Let's think step by step."
        extra = self.get_model_custom_instruction()
        if extra is not None:
            instruction = " ".join((instruction, extra))
        self.add_to_history("user", f"Question: {question}\n{instruction}")

    def add_to_history(self, role: str, content: str):
        """Append a complete message; not allowed after a partial message."""
        assert self.continue_final_message == False

        message = dict(role=role, content=content)
        self.history.append(message)

    def add_partial_to_history(self, role: str, content: str):
        """Append a message the model should continue, and switch rendering mode."""
        assert self.continue_final_message == False

        message = dict(role=role, content=content)
        self.history.append(message)
        # From now on the template must continue this message instead of
        # opening a new assistant turn.
        self.add_generation_prompt = False
        self.continue_final_message = True

    def add_think_token(self):
        """Pre-fill the assistant turn with a thinking marker where the model needs one."""
        model_config = ModelConfig.get(self.model_name)

        if "begin_think" in model_config:
            # Only these two models get an explicit assistant prefill.
            prefill_by_model = {
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": "<think>",
                "openai/gpt-oss-20b": "analysis",
            }
            prefill = prefill_by_model.get(self.model_name)
            if prefill is not None:
                self.add_partial_to_history("assistant", prefill)
            return

        if "fuzzy_end_think_list" in model_config:
            return

        print(f"ERROR: model {self.model_name} missing CoT separator config")
        exit(1)

    def make_prompt(self, tokenizer):
        """Render the history to a prompt string, adding the think prefill if enabled."""
        if self.invokes_cot:
            self.add_think_token()
        return self._apply_chat_template(tokenizer)

    def _apply_chat_template(self, tokenizer):
        """Run ``tokenizer.apply_chat_template`` on the history without tokenizing."""
        template_kwargs = dict(
            tokenize=False,
            add_generation_prompt=self.add_generation_prompt,
            continue_final_message=self.continue_final_message,
        )
        return tokenizer.apply_chat_template(self.history, **template_kwargs)


def load_tokenizer(model_name: str, cache_dir: str = "/tmp/cache"):
    """Load the tokenizer for ``model_name`` and make sure it has a pad token.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        cache_dir: Directory passed through as the download cache.

    Returns:
        The tokenizer; when it had no pad token, the EOS token is reused.
    """
    # NOTE(review): trust_remote_code=True lets the model repo run its own
    # code at load time; only use with repos you trust.
    tok = AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, trust_remote_code=True
    )

    if tok.pad_token is not None:
        return tok

    # No pad token configured: fall back to the end-of-sequence token.
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id
    return tok


def load_model(model_name: str, cache_dir: str = "/tmp/cache"):
    """Load a causal language model together with its tokenizer.

    Args:
        model_name: Identifier passed to the ``from_pretrained`` loaders.
        cache_dir: Directory used as the download cache for both loads.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The tokenizer is loaded first, then the weights.
    tok = load_tokenizer(cache_dir=cache_dir, model_name=model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
    return lm, tok


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

# %%
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
import random

# %%
# 1. Load the GPT-2 tokenizer and language-model head
model_name = "openai-community/gpt2"
print(f"Loading model: {model_name}...")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so the load and the switch to
# evaluation mode are chained into one statement.
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

def get_next_logits(input_ids):
    """Return the next-token logits for each sequence in ``input_ids``.

    Uses the module-level ``model`` and runs without gradient tracking.
    ``input_ids`` must be 2-D (batch_size, sequence_length); the result is
    the logits at the last sequence position for every batch row.
    """
    assert input_ids.ndim == 2, "Input IDs should be a 2D tensor (batch_size, sequence_length)"
    with torch.no_grad():
        all_logits = model(input_ids).logits
        last_position = all_logits[:, -1, :]
    return last_position

# Set pad token if it's not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# TODO: Implement the model dimension extraction method
#   1. Load the model and tokenizer
(model, tokenizer) = load_model(model_name, "/tmp/cache-tokenizer")
model.eval()  # make inference mode explicit for the freshly loaded model
#   2. Make n_queries (1k should be enough) random queries to collect logits
n_queries = 1000
logits_list = []
# Only the logit values are needed, so skip autograd bookkeeping; otherwise
# every stored row would keep its whole forward graph alive.
with torch.no_grad():
    for i in range(n_queries):
        input_len = random.randint(5, 15)
        # Random token ids in [1, 50256] (both ends inclusive).
        token_list = [random.randint(1, 50256) for _ in range(input_len)]
        tokens = torch.tensor([token_list])
        # `output` stays a global on purpose: a later cell inspects it.
        output = model(tokens)

        # Keep only the logits of the last position of the single batch row.
        logits = output.logits[0, -1]
        logits_list.append(logits)
#   3. Stack logits into a matrix Q
logits_matrix = torch.stack(logits_list, dim=0)
print(logits_matrix.shape)
#   4. Compute SVD of Q
#   5. Find the "elbow" in singular values to estimate dimension
#   6. Plot the results


Loading model: openai-community/gpt2...


In [5]:
#   4. Compute SVD of Q
# Only the singular values are used, so compute just those instead of the
# full decomposition (torch.svd is deprecated in favour of torch.linalg).
# detach() keeps the result free of autograd history.
S = torch.linalg.svdvals(logits_matrix.detach())
print(S)
#   5. Find the "elbow" in singular values to estimate dimension
#   6. Plot the results

tensor([57974.9297,   690.5708,   458.9148,   373.8094,   353.3015,   316.1564,
          295.4497,   284.7804,   247.3152,   231.1299],
       grad_fn=<LinalgSvdBackward0>)


In [None]:
# Inspect the shape of the logits from the most recent forward pass kept in `output`.
output.logits.shape

torch.Size([1, 5, 50257])

In [1]:
from typing import Callable

import torch
import torch.nn as nn


class UNet(nn.Module):
    """U-net architecture for 2D or 3D data."""
    def __init__(
        self,
        ndim: int = 2,
        activation_fn: Callable = nn.ReLU,
        activation_kwargs: dict = dict(inplace=True),
        dropout: float = 0,
        depth=3,
        n_in_channels=1,
        out_channels=1,
        mult_chan=64,
    ):
        """Constructor for UNet class.
        Args:
            ndim: Dimensionality of input data (2 or 3).
            activation_fn: Activation function to use.
            activation_kwargs: Keyword arguments for the activation function.
                The default dict is shared between calls and is never
                mutated here.
            dropout: Dropout probability.
            depth: Depth of the U-net.
            n_in_channels: Number of input channels.
            out_channels: Number of output channels.
            mult_chan: Factor applied to n_in_channels to get the number of
                feature channels produced by the recursive body.
        Raises:
            ValueError: If ndim is neither 2 nor 3."""
        super().__init__()
        self.depth = depth

        if ndim == 2:
            ConvNd = nn.Conv2d
        elif ndim == 3:
            ConvNd = nn.Conv3d
        else:
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")

        self.net_recurse = NetRecurse(
            n_in_channels=n_in_channels,
            mult_chan=mult_chan,
            depth=depth,
            ndim=ndim,
            activation_fn=activation_fn,
            activation_kwargs=activation_kwargs,
            dropout=dropout,
        )
        # NetRecurse emits n_in_channels * mult_chan channels, so the output
        # convolution must accept that many (equal to mult_chan only when
        # n_in_channels == 1).
        self.conv_out = ConvNd(n_in_channels * mult_chan, out_channels,
                               kernel_size=3, padding=1)

    def forward(self, x):
        """Run the recursive body, then project to out_channels."""
        x_rec = self.net_recurse(x)
        return self.conv_out(x_rec)


class NetRecurse(nn.Module):
    """Recursive definition of U-network."""
    def __init__(
        self,
        n_in_channels,
        mult_chan=2,
        depth=0,
        ndim: int = 2,
        activation_fn: Callable = nn.ReLU,
        activation_kwargs: dict = dict(inplace=True),
        dropout: float = 0,
    ):
        """Class for recursive definition of U-network.

        Parameters
        ----------
        n_in_channels
            Number of channels for input.
        mult_chan
            Factor to determine number of output channels
            (``n_in_channels * mult_chan``).
        depth
            If 0, this subnet will only be convolutions that multiply the
            channel count by ``mult_chan``. If greater than 0, a
            down-sampling path, a nested ``NetRecurse`` of depth
            ``depth - 1`` and an up-sampling path are added.
        ndim
            Dimensionality of input data (2 or 3).
        activation_fn
            Activation class to instantiate.
        activation_kwargs
            Keyword arguments passed to ``activation_fn``.
        dropout
            Dropout probability forwarded to the convolution sub-blocks.

        Raises
        ------
        ValueError
            If ``ndim`` is neither 2 nor 3.

        """
        super().__init__()
        self.depth = depth
        n_out_channels = n_in_channels * mult_chan
        if ndim == 2:
            ConvNd = nn.Conv2d
            InstanceNormNd = nn.InstanceNorm2d
            ConvTransposeNd = nn.ConvTranspose2d
        elif ndim == 3:
            ConvNd = nn.Conv3d
            InstanceNormNd = nn.InstanceNorm3d
            ConvTransposeNd = nn.ConvTranspose3d
        else:
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")

        self.sub_2conv_more = SubNet2Conv(
            n_in_channels,
            n_out_channels,
            ndim=ndim,
            activation_fn=activation_fn,
            activation_kwargs=activation_kwargs,
            dropout=dropout,
        )

        if depth > 0:
            # Merges the skip connection with the up-sampled features.
            self.sub_2conv_less = SubNet2Conv(
                2 * n_out_channels,
                n_out_channels,
                ndim=ndim,
                activation_fn=activation_fn,
                activation_kwargs=activation_kwargs,
                dropout=dropout,
            )
            # Strided convolution halves each spatial dimension.
            self.conv_down = ConvNd(
                n_out_channels, n_out_channels, kernel_size=2, stride=2)
            self.bn0 = InstanceNormNd(n_out_channels, affine=True)
            # Keyword-unpack the kwargs; a single * would pass the dict's
            # keys as positional arguments.
            self.relu0 = activation_fn(**activation_kwargs)
            # Transposed convolution undoes the down-sampling and halves the
            # channel count returned by the nested level.
            self.convt = ConvTransposeNd(
                2 * n_out_channels, n_out_channels, kernel_size=2, stride=2
            )
            self.bn1 = InstanceNormNd(n_out_channels, affine=True)
            self.relu1 = activation_fn(**activation_kwargs)
            # Nested levels always double the channel count.
            self.sub_u = NetRecurse(
                n_out_channels,
                mult_chan=2,
                depth=(depth - 1),
                ndim=ndim,
                activation_fn=activation_fn,
                activation_kwargs=activation_kwargs,
                dropout=dropout,
            )

    def forward(self, x):
        """Apply this level: convolve, and for depth > 0 recurse and merge."""
        x_2conv_more = self.sub_2conv_more(x)
        if self.depth == 0:
            return x_2conv_more

        # depth > 0: down-sample, recurse, up-sample, then merge with skip.
        x_conv_down = self.conv_down(x_2conv_more)
        x_bn0 = self.bn0(x_conv_down)
        x_relu0 = self.relu0(x_bn0)
        x_sub_u = self.sub_u(x_relu0)
        x_convt = self.convt(x_sub_u)
        x_bn1 = self.bn1(x_convt)
        x_relu1 = self.relu1(x_bn1)
        x_cat = torch.cat((x_2conv_more, x_relu1), 1)  # concatenate on channels
        return self.sub_2conv_less(x_cat)


class SubNet2Conv(nn.Module):
    """Subnetwork with two convolutional layers."""
    def __init__(
        self,
        n_in,
        n_out,
        ndim: int = 2,
        activation_fn: Callable = nn.ReLU,
        activation_kwargs: dict = dict(inplace=True),
        dropout: float = 0,
    ):
        """Constructor for SubNet2Conv class.
        Args:
            n_in: Number of input channels.
            n_out: Number of output channels.
            ndim: Dimensionality of input data (2 or 3).
            activation_fn: Activation function to use.
            activation_kwargs: Keyword arguments for the activation function.
            dropout: Dropout probability.
        Raises:
            ValueError: If ndim is neither 2 nor 3."""
        super().__init__()

        if ndim == 2:
            ConvNd = nn.Conv2d
            InstanceNormNd = nn.InstanceNorm2d
            DropoutNd = nn.Dropout2d
        elif ndim == 3:
            ConvNd = nn.Conv3d
            InstanceNormNd = nn.InstanceNorm3d
            DropoutNd = nn.Dropout3d
        else:
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")

        self.conv1 = ConvNd(n_in, n_out, kernel_size=3, padding=1)
        self.bn1 = InstanceNormNd(n_out, affine=True)
        # Keyword-unpack the kwargs: a single * would pass the dict's keys
        # (e.g. the string "inplace") as positional arguments.
        self.relu1 = activation_fn(**activation_kwargs)
        self.conv2 = ConvNd(n_out, n_out, kernel_size=3, padding=1)
        self.bn2 = InstanceNormNd(n_out, affine=True)
        self.relu2 = activation_fn(**activation_kwargs)
        self.dropout_layer = DropoutNd(p=dropout)

    def forward(self, x):
        """Apply conv -> norm -> activation twice, then optional dropout."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        # Skip the dropout layer entirely when it is configured as a no-op.
        if self.dropout_layer.p > 0:
            x = self.dropout_layer(x)
        return x

In [2]:
from torch import nn
from typing import Callable


class Discriminator(nn.Module):
    """Defines a PatchGAN critic (discriminator of Wasserstein GANs)
    The PatchGAN critic is a convolutional neural network that
    operates on small patches of the input image. It is used in Conditional GANs and
    WGANs to assign a realisticness score to each patch of the input image.
    Our implementation has a receptive field of 54x54 pixels."""

    def __init__(
            self,
            ndim: int = 2,
            input_nc: int = 2,  # number of channels in input images: source+target
            activation_fn: Callable = nn.LeakyReLU,
            activation_kwargs: dict = dict(negative_slope=0.05, inplace=True),
            ndf: int = 128,  # number of filters in the first conv layer
            norm_layer: Callable = nn.InstanceNorm3d,
    ):
        """Build the critic as a sequential stack of convolutions.

        Args:
            ndim: Dimensionality of input data (2 or 3).
            input_nc: Number of channels of the input tensor.
            activation_fn: Activation class to instantiate after each
                convolution except the last.
            activation_kwargs: Keyword arguments passed to activation_fn.
            ndf: Number of filters of the first convolution; later layers
                use multiples of it (capped at 8 * ndf).
            norm_layer: Kept for interface compatibility. It is currently
                overridden by the instance norm matching ndim.

        Raises:
            ValueError: If ndim is neither 2 nor 3.
        """
        n_layers = 3  # number of strided conv layers in the discriminator
        if ndim == 2:
            ConvNd = nn.Conv2d
            norm_layer = nn.InstanceNorm2d
        elif ndim == 3:
            ConvNd = nn.Conv3d
            norm_layer = nn.InstanceNorm3d
        else:
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")

        super().__init__()
        kw = 4  # kernel size
        padw = 1  # padding
        # Keyword-unpack the kwargs everywhere below: a single * would hand
        # the dict's keys to the activation as positional values (e.g.
        # negative_slope="negative_slope"), which fails at forward time.
        sequence = [ConvNd(input_nc, ndf, kernel_size=kw, stride=2,
                           padding=padw), activation_fn(**activation_kwargs)]
        nf_mult = 1  # filter multiplier of the current layer
        nf_mult_prev = 1  # filter multiplier of the previous layer
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                ConvNd(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                       stride=2, padding=padw),
                norm_layer(ndf * nf_mult, affine=True),
                activation_fn(**activation_kwargs)
            ]

        # The last two convolutions use stride 1 with a 3-wide kernel, so
        # they keep the spatial size.
        kw = 3
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            ConvNd(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                   stride=1, padding=padw),
            norm_layer(ndf * nf_mult, affine=True),
            activation_fn(**activation_kwargs)
        ]

        # output 1 channel prediction map
        sequence += [ConvNd(ndf * nf_mult, 1, kernel_size=kw,
                            stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)

In [5]:
import torch
# from torchvision.transforms import ToTensor
import lightning as L
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
import io

from ..utils.losses import GenLoss
from ..utils.visualize import display_batch, hstack_figures
from ..utils.metrics import classification_metrics


class WGANGP(L.LightningModule):
    """Lightning module that trains a generator, optionally against a WGAN-GP critic.

    With ``config['adversarial_training']`` enabled, each training step
    updates the critic with a Wasserstein loss plus gradient penalty and
    updates the generator every ``n_critic`` batches. Otherwise only a masked
    L1 loss is optimized. Optimization is manual
    (``automatic_optimization = False``).

    Targets may contain NaNs; they mark missing values and are masked out of
    the losses.
    """

    def __init__(self, Gen, Disc, config):
        """Store the networks and read the settings from ``config``.

        Args:
            Gen: Generator network, called as ``Gen(x)``.
            Disc: Critic network; only stored when adversarial training is on.
            config: Mapping providing ``target_channels``, ``len_val_loader``,
                ``use_classification_metric``, ``adversarial_training``,
                ``lr_g`` and ``lr_d``; plus ``classifier`` and
                ``classification_loader`` when the classification metric is
                enabled.
        """
        super().__init__()
        self.automatic_optimization = False
        self.L1Loss = torch.nn.L1Loss(reduction="none")

        self.target_channels = config['target_channels']
        self.len_val_loader = config['len_val_loader']
        self.use_classification_metric = config['use_classification_metric']
        self.adversarial_training = config['adversarial_training']

        if self.use_classification_metric:  # compute classification metric
            self.classifier = config['classifier']
            self.classification_loader = config['classification_loader']
            self.classifier.freeze()

        # networks
        self.generator = Gen
        self.gen_criterion = GenLoss(loss='l1')

        if self.adversarial_training:
            self.discriminator = Disc

        # val step variables
        self.fig_list = []
        self.counter = 0

        # hyperparameters
        self.lr_g = config['lr_g']  # generator
        self.lr_d = config['lr_d']  # discriminator (critic)

    def forward(self, z):
        """Run the generator on ``z``."""
        return self.generator(z)

    def compute_gp(self, real_data, fake_data):
        """Return the gradient penalty ``mean((||grad||_2 - 1)^2)``.

        The critic is evaluated on random per-sample interpolations between
        ``real_data`` and ``fake_data``, and the gradient of its output with
        respect to the interpolation is penalized.
        """
        batch_size = real_data.size(0)
        # Sample one epsilon per sample from a uniform distribution. The
        # shape has one trailing singleton axis per non-batch dimension, so
        # both 2D (4-D tensors) and 3D (5-D tensors) data are supported.
        eps_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        eps = torch.rand(eps_shape).to(real_data.device)
        eps = eps.expand_as(real_data)

        # Interpolation between real data and fake data.
        interpolation = (eps * real_data + (1 - eps) *
                         fake_data).requires_grad_(True)

        # get logits for interpolated images
        interp_logits = self.discriminator(interpolation)
        grad_outputs = torch.ones_like(interp_logits)

        # Compute Gradients; create_graph=True keeps the penalty itself
        # differentiable so it can be back-propagated through.
        gradients = torch.autograd.grad(
            outputs=interp_logits,
            inputs=interpolation,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]

        # Compute and return Gradient Norm
        gradients = gradients.reshape(batch_size, -1)
        grad_norm = gradients.norm(2, 1)
        return torch.mean((grad_norm - 1) ** 2)

    def training_step_wgan(self, batch, batch_idx):
        """Run one adversarial step: update the critic, and the generator every 5th batch."""
        n_critic = 5  # number of training steps for discriminator per iter
        opt_g, opt_d = self.optimizers()

        x, y_real = batch

        # mask of the elements where y_real has values (not NaNs)
        mask = ~(y_real.isnan())

        # substitute NaN values in y_real with zeros
        y_real = torch.nan_to_num(y_real, nan=0.0)

        y_fake = self.generator(x)

        # keep only the elements of y_fake with a valid counterpart in y_real
        y_fake_masked = y_fake * mask

        ##########################
        # Optimize Discriminator #
        ##########################
        # the critic has to determine the realisticness of source+target
        real_concat_with_input = torch.cat((x, y_real), 1)
        # output of the critic on real input
        real_out = self.discriminator(real_concat_with_input).mean()

        # detached so the critic update does not back-propagate into the generator
        fake_concat_with_input = torch.cat((x, y_fake_masked.detach()), 1)
        # output of the critic on generated input
        fake_out = self.discriminator(fake_concat_with_input).mean()

        lambda_gp = 10
        gp = self.compute_gp(real_concat_with_input, fake_concat_with_input)

        # Back-propagation through the penalty is enabled by
        # create_graph=True inside compute_gp; nothing needs to be set on
        # the loss tensor itself.
        was_loss = fake_out - real_out + lambda_gp * gp

        opt_d.zero_grad()
        self.manual_backward(was_loss, retain_graph=True)
        opt_d.step()

        ######################
        # Optimize Generator #
        ######################
        if batch_idx % n_critic == 0:  # update generator every n_critic steps
            gen_fake_concat_with_input = torch.cat((x, y_fake_masked), 1)
            disc_fake_out = self.discriminator(gen_fake_concat_with_input)
            adv_weight = 0.05
            epoch = self.current_epoch
            # compute generator loss
            g_loss, l1_loss_train, adv_loss = self.gen_criterion(
                disc_fake_out, y_fake_masked, y_real, epoch=epoch, mask=mask, adv_weight=adv_weight)

            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()

            self.log_dict({"g_loss": g_loss, "was_loss": was_loss,
                           "gp": gp, "adv_loss": adv_loss,
                           "train_L1": l1_loss_train,
                           }, prog_bar=True)

    def training_step_unet(self, batch, batch_idx):  # without adv loss
        """Run one generator-only step using the masked L1 loss."""
        opt_g = self.optimizers()

        x, y_real = batch

        # mask of the elements where y_real has values (not NaNs)
        mask = ~(y_real.isnan())

        # substitute NaN values in y_real with zeros
        y_real = torch.nan_to_num(y_real, nan=0.0)

        y_fake = self.generator(x)

        # keep only the elements of y_fake with a valid counterpart in y_real
        y_fake_masked = y_fake * mask

        ######################
        # Optimize Generator #
        ######################

        # the pixel wise loss would be artificially low because the non existing channels in y_real are set to 0 also in y_fake_masked
        # reduce the L1loss by computing the mean only on the valid channels
        g_loss = self.L1Loss(y_fake_masked, y_real)[mask].mean()
        opt_g.zero_grad()
        self.manual_backward(g_loss)  # compute gradients
        opt_g.step()  # update weights of the generator

        if batch_idx % 5 == 0:
            self.log_dict({"g_loss": g_loss,
                           "train_L1": g_loss,
                           }, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Dispatch to the adversarial or the L1-only training step."""
        if self.adversarial_training:
            self.training_step_wgan(batch, batch_idx)
        else:
            self.training_step_unet(batch, batch_idx)

    def configure_optimizers(self):
        """Return RMSprop optimizers: generator only, or (generator, critic)."""
        opt_g = torch.optim.RMSprop(
            params=self.generator.parameters(), lr=self.lr_g)
        if self.adversarial_training:
            opt_d = torch.optim.RMSprop(
                params=self.discriminator.parameters(), lr=self.lr_d)
            return opt_g, opt_d

        return opt_g

    def on_train_epoch_start(self):
        """Wait for pending CUDA work before the epoch starts."""
        # Guarded so that the hook is a no-op on hosts without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_validation_epoch_end(self):
        """Log F1 scores averaged over the batches of the classification loader."""
        if not self.use_classification_metric:
            return
        f1_macro, f1_micro, f1_weighted = 0, 0, 0
        for batch in self.classification_loader:
            f1_macro_batch, f1_micro_batch, f1_weighted_batch = classification_metrics(
                self.generator, self.classifier, batch)
            f1_macro += f1_macro_batch
            f1_micro += f1_micro_batch
            f1_weighted += f1_weighted_batch
        # len() of the loader: the sums above are divided by this count.
        n_samples = len(self.classification_loader)
        f1_macro /= n_samples
        f1_micro /= n_samples
        f1_weighted /= n_samples
        self.log("train_f1_macro", f1_macro, on_step=False,
                 on_epoch=True, prog_bar=True)
        self.log("train_f1_micro", f1_micro, on_step=False,
                 on_epoch=True, prog_bar=True)
        self.log("train_f1_weighted", f1_weighted,
                 on_step=False, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log masked L1 metrics and periodically render prediction figures.

        Returns:
            The unreduced element-wise L1 loss tensor.
        """
        x, y_real = batch
        with torch.no_grad():
            y_fake = self.generator(x)
        mask = ~(y_real.isnan())

        # Element-wise loss; NaN targets are excluded by the mask below.
        l1loss = self.L1Loss(y_fake, y_real)
        l1loss_mean = l1loss[mask].mean()
        self.log("val_L1", l1loss_mean, on_step=False,
                 on_epoch=True, prog_bar=True, sync_dist=True)

        if self.use_classification_metric:
            f1_macro, f1_micro, f1_weighted = classification_metrics(
                self.generator, self.classifier, batch, prediction=y_fake)
            self.log("f1_macro", f1_macro, on_step=False,
                     on_epoch=True, prog_bar=True)
            self.log("f1_micro", f1_micro, on_step=False,
                     on_epoch=True, prog_bar=True)
            self.log("f1_weighted", f1_weighted, on_step=False,
                     on_epoch=True, prog_bar=True)

        for i, ch in enumerate(self.target_channels):
            # select the channel and compute the mean only for that channel, for the samples in the batch that have that channel populated
            channel_loss = l1loss[:, i, ...][mask[:, i, ...]].mean()
            if not channel_loss.isnan():
                self.log(f"L1_{ch}", channel_loss, sync_dist=True)

        # Render a figure roughly six times per validation pass; max(1, ...)
        # avoids a modulo-by-zero when the loader has fewer than 6 batches.
        fig_every = max(1, self.len_val_loader // 6)

        if self.counter > 5:
            tensorboard = self.logger.experiment
            # stack horizontally the figures in self.fig_list
            image = hstack_figures(self.fig_list)
            # tensorboard.add_image('val_fig', ToTensor()(
            #     image), global_step=self.global_step)
            self.fig_list = []
            self.counter = 0

        elif batch_idx % fig_every == 0:
            fig = display_batch(x, y_real, pred=y_fake, target_channel_names=self.target_channels,
                                limit_images=16, show3d='middle', batch=self.counter)
            canvas = FigureCanvasAgg(fig)
            # Get the renderer's buffer and string it into a PNG
            buf = io.BytesIO()
            canvas.print_png(buf)
            img = Image.open(buf)
            self.fig_list.append(img)
            self.counter += 1

        return l1loss

AttributeError: partially initialized module 'torchvision' from '/home/lorenzovenieri/aisb/.venv/lib64/python3.13/site-packages/torchvision/__init__.py' has no attribute 'extension' (most likely due to a circular import)