# In [None]:
# Bayesian neural networks (شبکه های عصبی بیزی)

import numpy as np
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample
from tqdm import tqdm


class BayesianConv3d(PyroModule):
    """3-D convolution whose weight and bias are latent Pyro sample sites.

    Both the weight and the bias have an i.i.d. ``Normal(0, prior_scale)``
    prior. The prior's location and scale are stored as non-persistent
    buffers, so ``module.to(device)`` moves the prior -- and therefore every
    weight sampled from it -- onto the same device as the input. (Priors built
    from Python floats always sample on the CPU.)

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int for a cubic kernel, or a 3-sequence ``(kD, kH, kW)``.
        stride: forwarded unchanged to ``F.conv3d``.
        padding: forwarded unchanged to ``F.conv3d``.
        prior_scale: standard deviation of the Normal prior (default 0.1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, prior_scale=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        if isinstance(kernel_size, int):
            kernel_dims = (kernel_size,) * 3
        else:
            kernel_dims = tuple(kernel_size)
        weight_shape = [out_channels, in_channels, *kernel_dims]

        # persistent=False keeps these out of state_dict, so checkpoints keep
        # the same keys as before; the buffers still follow .to(device).
        self.register_buffer("_prior_loc", torch.tensor(0.0), persistent=False)
        self.register_buffer("_prior_scale", torch.tensor(float(prior_scale)), persistent=False)

        # Callable priors are evaluated against the module at sample time, so
        # they pick up the buffers' current device.
        self.weight = PyroSample(
            lambda mod: dist.Normal(mod._prior_loc, mod._prior_scale)
            .expand(weight_shape)
            .to_event(len(weight_shape))
        )
        self.bias = PyroSample(
            lambda mod: dist.Normal(mod._prior_loc, mod._prior_scale)
            .expand([out_channels])
            .to_event(1)
        )

    def forward(self, x):
        """Convolve ``x`` with one draw of the weight and bias (weight drawn first)."""
        return F.conv3d(x, self.weight, self.bias, stride=self.stride, padding=self.padding)

class BayesianCNN3D(PyroModule):
    """3-D convolutional segmentation network with Bayesian convolution weights.

    Every convolution weight/bias is a latent ``PyroSample`` site. ``model``
    scores a per-voxel categorical likelihood, and ``guide`` is a mean-field
    ``AutoNormal`` posterior over all of those latent sites, so
    ``SVI(net.model, net.guide, ...)`` has parameters to optimize.

    Args:
        in_channels: number of input channels (image modalities).
        num_classes: number of output classes per voxel.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = BayesianConv3d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = BayesianConv3d(32, 64, kernel_size=3, padding=1)
        self.conv3 = BayesianConv3d(64, 128, kernel_size=3, padding=1)
        # 1x1x1 convolution mapping features to per-class logits.
        self.conv4 = BayesianConv3d(128, num_classes, kernel_size=1)

        self.pool = nn.MaxPool3d(2)

        # Variational posterior over every latent site of `model`. It is kept
        # inside a tuple so nn.Module does not register it as a submodule of
        # the network, which would re-prefix its parameter names.
        self._autoguide = (AutoNormal(self.model),)

    def forward(self, x, y=None):
        """Return class logits at 1/4 of the input's spatial resolution.

        The two ``MaxPool3d(2)`` stages each halve (floor) every spatial dim.
        ``y`` is unused; it is accepted only for call-signature compatibility.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.conv4(x)
        return x

    def model(self, x, y=None):
        """Pyro model: a categorical likelihood for every voxel of the input grid.

        Args:
            x: input batch; expected shape (batch, in_channels, D, H, W).
            y: optional integer labels; expected shape (batch, D, H, W) with
                values in ``[0, num_classes)``. ``None`` samples labels instead.

        Returns:
            The raw ``forward`` logits, (batch, num_classes, D', H', W').
        """
        logits = self.forward(x)
        # forward() pools twice, so bring the logits back onto the input grid
        # that the labels are defined on before scoring them.
        voxel_logits = F.interpolate(logits, size=x.shape[2:], mode="trilinear", align_corners=False)
        # Categorical expects the class axis last: (batch, D, H, W, classes).
        voxel_logits = voxel_logits.movedim(1, -1)
        # Fold all spatial dims into the event so the plate covers only batch.
        obs_dist = dist.Categorical(logits=voxel_logits).to_event(voxel_logits.dim() - 2)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", obs_dist, obs=y)
        return logits

    def guide(self, x, y=None):
        """Variational guide for SVI: an AutoNormal over every weight/bias site."""
        return self._autoguide[0](x, y)

def load_and_preprocess_data(t1, t1ce, t2, flair, seg):
    """Stack four same-shaped volumes into a one-sample batch with its labels.

    Returns:
        ``(images, labels)`` where ``images`` has shape ``(1, 4, *t1.shape)``
        with channels ordered t1, t1ce, t2, flair, and ``labels`` has shape
        ``(1, *seg.shape)``.
    """
    modalities = np.stack([t1, t1ce, t2, flair], axis=0)
    batched_images = np.expand_dims(modalities, axis=0)
    batched_labels = np.array([seg])
    return batched_images, batched_labels

def prepare_data(data, seg):
    """Wrap NumPy image and label arrays in a ``TensorDataset``.

    ``data`` becomes a float32 tensor and ``seg`` an int64 tensor; both must
    share the same leading (sample) dimension.
    """
    images = torch.from_numpy(data).float()
    labels = torch.from_numpy(seg).long()
    return torch.utils.data.TensorDataset(images, labels)


def dice_loss(pred, target, smooth=1e-5):
    """Soft Dice loss averaged over batch and channels.

    Sums run over every dimension after the first two, so 2-D
    ``(N, C, H, W)`` and 3-D ``(N, C, D, H, W)`` inputs are both supported.

    Args:
        pred: predictions, (batch, channels, *spatial); presumably per-class
            probabilities -- confirm with the caller.
        target: targets broadcastable to ``pred``; presumably one-hot.
        smooth: additive constant that keeps empty channels from giving 0/0.

    Returns:
        Scalar tensor ``1 - mean(dice)``; 0 for a perfect overlap.
    """
    spatial_dims = tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dim=spatial_dims)
    union = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()


def train(model, train_loader, val_loader, num_epochs, learning_rate, device):
    """Fit ``model`` with stochastic variational inference (Trace_ELBO + Adam).

    Args:
        model: ``nn.Module`` exposing Pyro ``model(x, y)`` and ``guide(x, y)``
            callables; it is moved to ``device`` before training.
        train_loader: sized iterable of ``(data, target)`` batches for SVI steps.
        val_loader: sized iterable of ``(data, target)`` batches scored only.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device every batch (and the model) is moved to.

    Returns:
        List with one ``(mean_train_loss, mean_val_loss)`` tuple per epoch.
        Each loss is the ELBO loss returned by Pyro, averaged over batches.
        An empty loader reports 0.0 instead of dividing by zero.
    """
    # NOTE(review): .to() only moves buffers/parameters; priors built from
    # Python floats inside PyroSample still sample on the CPU.
    model.to(device)
    optimizer = optim.Adam({"lr": learning_rate})
    elbo = Trace_ELBO()
    svi = SVI(model.model, model.guide, optimizer, loss=elbo)

    history = []
    for epoch in range(num_epochs):
        # Re-enter training mode every epoch; validation below switches to eval.
        model.train()
        train_loss = 0.0
        for data, target in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            data, target = data.to(device), target.to(device)
            train_loss += svi.step(data, target)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                val_loss += elbo.loss(model.model, model.guide, data, target)

        mean_train = train_loss / max(len(train_loader), 1)
        mean_val = val_loss / max(len(val_loader), 1)
        history.append((mean_train, mean_val))
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean_train:.4f}, Val Loss: {mean_val:.4f}")
    return history
