In [6]:
import math
from copy import deepcopy

import pytorch_lightning as pl
import torch

class BaseSSL(pl.LightningModule):
    """Common scaffold for the self-supervised learners in this notebook.

    Holds a backbone, a projection head slot and an opaque ``config`` object.
    ``projection_head`` is initialised to ``None`` here; it has to be assigned
    (by a subclass or by the caller) before ``forward`` can be used.

    Args:
        backbone: Callable feature extractor; ``forward`` calls it on the input
            and flattens the result from dimension 1 onwards.
        config: Arbitrary configuration object, stored as-is on ``self.config``.
    """

    def __init__(self, backbone, config=None):
        super().__init__()
        self.backbone = backbone
        self.projection_head = None  # must be assigned before calling forward()
        self.config = config

    def forward(self, x):
        """Run the backbone and the projection head on ``x``.

        Returns:
            dict: ``{"features": <flattened backbone output>,
            "projection": <projection head output>}``.

        Raises:
            TypeError: If ``projection_head`` has not been assigned yet.
        """
        if self.projection_head is None:
            # Calling None would raise an opaque "'NoneType' object is not
            # callable"; keep the same exception type but say what is missing.
            raise TypeError(
                f"{type(self).__name__}.projection_head is None; "
                "assign a projection head before calling forward()."
            )
        features = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(features)
        return {"features": features, "projection": z}

    def get_params(self):
        """Placeholder for the existing parameter-group logic; currently returns None."""
        # Your existing implementation
        pass

    def configure_optimizers(self):
        """Placeholder for the existing optimizer setup; currently returns None."""
        # Your existing implementation
        pass

class SimCLRSSL(BaseSSL):
    """SimCLR-style learner built on ``BaseSSL``.

    ``criterion`` is left as ``None`` by the constructor and must be assigned
    before ``training_step`` is called.
    """

    def __init__(self, backbone, config):
        # Delegate backbone/config storage to BaseSSL.__init__().
        super().__init__(backbone, config)
        self.criterion = None

    def training_step(self, batch, batch_idx):
        """Project the first element of ``batch`` and return the criterion's loss.

        ``batch_idx`` is accepted for the Lightning hook signature and unused.
        """
        # presumably batch[0] holds the augmented views (x1, x2) -- confirm with the dataloader
        views = batch[0]
        outputs = self(views)  # dispatches to BaseSSL.forward()
        return self.criterion(outputs["projection"])
    
class MoCoSSL(BaseSSL):
    """Momentum-encoder learner: an online network plus an EMA copy of the backbone.

    The online branch is ``BaseSSL.forward`` (backbone + projection head). The
    momentum branch is a deep copy of the backbone whose weights are updated
    after every training batch as an exponential moving average of the online
    backbone, with the coefficient ``tau`` following a cosine schedule from
    ``base_tau`` to ``final_tau``.

    Args:
        backbone: Feature extractor; stored as the online backbone and
            deep-copied to create ``momentum_encoder``.
        config: Configuration (dict or attribute-style object). Optional
            entries ``base_tau`` (default 0.996) and ``final_tau`` (default 1.0)
            set the EMA schedule end points.
    """

    def __init__(self, backbone, config):
        super().__init__(backbone, config)  # Calls BaseSSL.__init__()
        self.momentum_encoder = deepcopy(backbone)
        # The momentum branch is only ever changed by the EMA rule, so take it
        # out of autograd. Guarded because a non-module placeholder backbone
        # (e.g. a string, as used elsewhere in this notebook) has no parameters.
        if isinstance(self.momentum_encoder, torch.nn.Module):
            self.momentum_encoder.requires_grad_(False)
        self.criterion = None  # must be assigned before training_step()

        def _setting(key, default):
            # Accept both dict-style and attribute-style configs (and None).
            if isinstance(config, dict):
                return config.get(key, default)
            return getattr(config, key, default)

        self.base_tau = _setting("base_tau", 0.996)
        self.final_tau = _setting("final_tau", 1.0)

    def training_step(self, batch, batch_idx):
        """Compute the loss between the online projection and the momentum output.

        ``batch_idx`` is accepted for the Lightning hook signature and unused.
        """
        x = batch[0]
        z_online = self(x)["projection"]  # Online network
        with torch.no_grad():
            # NOTE(review): this is the raw momentum-backbone output (no flatten,
            # no projection head), unlike z_online -- confirm the criterion
            # expects that pairing.
            z_momentum = self.momentum_encoder(x)  # Momentum network
        loss = self.criterion(z_online, z_momentum)
        return loss

    def on_train_batch_end(self, outputs=None, batch=None, batch_idx=None, *args, **kwargs):
        """Lightning hook: refresh the momentum encoder after each training batch.

        Lightning invokes this hook with ``(outputs, batch, batch_idx)``; the
        arguments are accepted (and ignored) so the call does not fail, while a
        bare ``on_train_batch_end()`` call keeps working.
        """
        self._update_momentum_encoder()

    def _current_tau(self):
        """Return the EMA coefficient for the current global step.

        Cosine schedule: ``base_tau`` at step 0, ``final_tau`` at
        ``trainer.max_steps``. When no step limit is configured (``max_steps``
        is ``None`` or not positive, e.g. Lightning's default ``-1``) the
        schedule has no horizon, so ``base_tau`` is used throughout.
        """
        max_steps = self.trainer.max_steps
        if max_steps is None or max_steps <= 0:
            return self.base_tau
        # Clamp so tau stays at final_tau if training runs past max_steps.
        progress = min(self.global_step / max_steps, 1.0)
        return self.final_tau - (self.final_tau - self.base_tau) * (
            (1 + math.cos(math.pi * progress)) / 2
        )

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA step: ``p_momentum <- tau * p_momentum + (1 - tau) * p_online``."""
        tau = self._current_tau()
        for p_online, p_momentum in zip(
            self.backbone.parameters(),
            self.momentum_encoder.parameters()
        ):
            # In-place update keeps the parameter tensors (and any references
            # to them) intact instead of rebinding ``.data``.
            p_momentum.mul_(tau).add_(p_online.detach(), alpha=1 - tau)

class EWC:
    """Elastic-Weight-Consolidation style regulariser driven by strategy hooks.

    Args:
        ewc_lambda: Multiplier applied to the penalty before it is added to
            the strategy's loss.
        mode: Stored on the instance; not read by the code in this class.
        decay_factor: Stored on the instance; not read by the code in this class.
    """

    def __init__(self, ewc_lambda=0.4, mode="separate", decay_factor=None):
        self.ewc_lambda = ewc_lambda
        self.mode = mode
        self.decay_factor = decay_factor
        # {exp_id: {param_name: tensor}} -- parameter snapshots per experience
        self.saved_params = {}
        # {exp_id: {param_name: tensor}} -- importance weights (Fisher information)
        self.importances = {}

    def before_backward(self, strategy, **kwargs):
        """Add the weighted quadratic penalty to ``strategy.loss``.

        Does nothing while no parameters have been saved. Otherwise, for every
        saved experience and every parameter of ``strategy.model`` whose name
        appears in that snapshot, accumulates
        ``sum(importance * (param - saved) ** 2)`` and adds
        ``ewc_lambda * total`` to ``strategy.loss``.
        """
        if not self.saved_params:
            return  # no previous experience recorded yet

        def _penalty_terms():
            for task_id, anchors in self.saved_params.items():
                for pname, current in strategy.model.named_parameters():
                    if pname not in anchors:
                        continue
                    weight = self.importances[task_id][pname]
                    drift = current - anchors[pname]
                    yield (weight * drift.pow(2)).sum()

        total_penalty = sum(_penalty_terms())
        strategy.loss += self.ewc_lambda * total_penalty

    def after_training_exp(self, strategy, **kwargs):
        """Refresh importances, then snapshot the model parameters."""
        self._compute_fisher_matrix(strategy)
        self._save_params(strategy.model)

    def _compute_fisher_matrix(self, strategy):
        """Placeholder for the existing Fisher computation; currently a no-op."""
        # Your existing EWC Fisher computation logic
        pass

    def _save_params(self, model):
        """Placeholder for the existing parameter snapshot logic; currently a no-op."""
        # Your existing parameter saving logic
        pass

In [None]:
# NOTE(review): the backbone is passed as the string "resnet18"; BaseSSL.forward
# calls self.backbone(x), so a callable module is needed before a forward pass.
simclr = SimCLRSSL("resnet18", {"batch_size": 32})



In [10]:
class EWCModel(SimCLRSSL, EWC):
    """SimCLR learner with the EWC mixin's state and hooks attached.

    ``config`` must provide the keys ``"ewc_lambda"`` and ``"ewc_mode"``.

    NOTE(review): nothing in this notebook calls ``before_backward`` or
    ``after_training_exp`` on the model -- confirm the training loop does.
    """

    def __init__(self, backbone, config):
        # Initialise each base explicitly: first the Lightning side, then the
        # EWC bookkeeping.
        SimCLRSSL.__init__(self, backbone, config)
        EWC.__init__(
            self,
            ewc_lambda=config["ewc_lambda"],
            mode=config["ewc_mode"],
        )

# Build the combined SimCLR + EWC model; the backbone is a string placeholder here.
model = EWCModel(
    "resnet18",
    {"batch_size": 32, "ewc_lambda": 0.4, "ewc_mode": "separate"},
)

In [12]:
# Check which base class supplies the hook: the repr below shows it resolves to EWC.before_backward.
model.before_backward

<bound method EWC.before_backward of EWCModel()>