<a href="https://colab.research.google.com/github/tsanoop887-hash/AIF360/blob/main/fusion_ai_model_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast ,GradScaler
import numpy as np

In [2]:
# Compute device for all models and tensors: the first CUDA GPU when one is
# visible to PyTorch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Physical parameters: thermal diffusivity, final time, and spatial domain

In [3]:
alpha= 0.01
t_final=1.0
domain =[0.0,1.0]

Utilities

In [4]:
def meshgrid2d(n):
    """Return ``(X, Y)`` coordinate arrays for an ``n x n`` grid over ``domain``.

    Both axes run from ``domain[0]`` to ``domain[1]`` (inclusive) with ``n``
    evenly spaced points. ``np.meshgrid`` is called with ``indexing='xy'``,
    so ``X`` varies along the second (column) axis and ``Y`` along the first.
    """
    axis = np.linspace(domain[0], domain[1], n)
    grid_x, grid_y = np.meshgrid(axis, axis, indexing='xy')
    return grid_x, grid_y

Fourier feature mapping


In [5]:
class FFMapping(nn.Module):
    """Random Fourier feature encoding ``x -> [sin(2*pi*x@B), cos(2*pi*x@B)]``.

    ``B`` is a Gaussian matrix of shape ``(in_dim, mapping_size)`` multiplied
    by ``scale``. It is registered as a buffer rather than a parameter, so the
    optimizer does not update it, but it follows ``.to(device)`` and is saved
    in the state dict.

    Args:
        in_dim: size of the last axis of the input.
        mapping_size: number of random frequencies; the output's last axis
            has ``2 * mapping_size`` features (sines first, then cosines).
        scale: multiplier applied to the standard-normal frequencies.
    """

    def __init__(self, in_dim=3, mapping_size=64, scale=10.0):
        super().__init__()
        freqs = torch.randn(in_dim, mapping_size) * scale
        self.register_buffer('B', freqs)

    def forward(self, x):
        angles = (2 * math.pi * x) @ self.B
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


In [6]:
class PINN(nn.Module):
    """Physics-informed MLP mapping coordinates ``(x, y, t)`` to a scalar ``T``.

    Inputs are first lifted with random Fourier features (``FFMapping`` with
    ``scale=5.0``). The result then passes through ``layers`` hidden
    ``Linear + SiLU`` stages and a single ``Linear(hidden, 1)`` output head.
    All linear weights are Xavier-uniform initialised with zero biases.

    Args:
        in_dim: number of input coordinates (3 for ``(x, y, t)``).
        hidden: width of every hidden layer.
        layers: number of hidden ``Linear + SiLU`` stages; must be >= 1.
        ff_size: number of Fourier frequencies; the first linear layer sees
            ``2 * ff_size`` features.

    Raises:
        ValueError: if ``layers < 1``.
    """

    def __init__(self, in_dim=3, hidden=256, layers=6, ff_size=64):
        super().__init__()
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.ff = FFMapping(in_dim=in_dim, mapping_size=ff_size, scale=5.0)
        input_dim = ff_size * 2  # sin + cos features from FFMapping

        seq = [nn.Linear(input_dim, hidden), nn.SiLU()]
        for _ in range(layers - 1):
            seq += [nn.Linear(hidden, hidden), nn.SiLU()]
        # Exactly one scalar output head, appended after the whole hidden stack.
        seq.append(nn.Linear(hidden, 1))

        self.net = nn.Sequential(*seq)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every ``nn.Linear`` in ``self.net``."""
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, xyt):
        """Return ``T`` with last-axis size 1 for ``xyt`` whose last axis has ``in_dim`` coordinates."""
        return self.net(self.ff(xyt))

In [7]:
class SpectralConv2d(nn.Module):
    """2-D Fourier layer: rFFT -> keep low modes -> complex channel mix -> inverse rFFT.

    Takes a real tensor of shape ``(batch, in_channels, H, W)`` and returns a
    float32 tensor of shape ``(batch, out_channels, H, W)``.

    Only the frequency block ``[:modes1, :modes2]`` of ``rfft2(x)`` is
    retained and mixed across channels; every other coefficient is zeroed.
    When the grid has fewer frequencies than requested, the block is clipped
    to what is available.

    NOTE(review): only non-negative frequencies along H are kept. Reference
    FNO implementations also mix the ``[-modes1:]`` block with a second
    weight tensor; confirm whether this asymmetry is intended.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of retained Fourier modes along H.
        modes2: number of retained Fourier modes along W.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights are stored as a real tensor with a trailing
        # (real, imag) axis of size 2, so that torch.view_as_complex can
        # reinterpret them without copying in forward().
        self.weights = nn.Parameter(
            self.scale * torch.randn(in_channels, out_channels, modes1, modes2, 2)
        )

    def compl_mul2d(self, input, weights):
        """Per-mode channel mixing: ``(b, i, x, y) x (i, o, x, y) -> (b, o, x, y)``."""
        return torch.einsum('bixy,ioxy->boxy', input, weights)

    def forward(self, x):
        batchsize, _, height, width = x.shape
        # Run the spectral path in float32/complex64. Under autocast, x may
        # arrive as float16, which would not mix with the complex64 weights
        # and output buffer.
        with torch.autocast(device_type=x.device.type, enabled=False):
            x_ft = torch.fft.rfft2(x.float(), norm='ortho')
            out_ft = torch.zeros(
                batchsize, self.out_channels, height, width // 2 + 1,
                dtype=torch.cfloat, device=x.device,
            )
            mx = min(self.modes1, x_ft.size(-2))
            my = min(self.modes2, x_ft.size(-1))
            # Clip the weights to the same mode block as the input slice so
            # the einsum shapes agree on small grids.
            weight = torch.view_as_complex(self.weights)[:, :, :mx, :my]
            out_ft[:, :, :mx, :my] = self.compl_mul2d(x_ft[:, :, :mx, :my], weight)
            return torch.fft.irfft2(out_ft, s=(height, width), norm='ortho')

In [8]:
class FNO2d(nn.Module):
    """Two-layer 2-D Fourier Neural Operator on a regular grid.

    Input ``x`` has shape ``(batch, nx, ny, in_channels)``, e.g. channels
    ``(x, y, forcing)``. The model works in four stages:

    1. Each grid point is lifted to ``width`` channels by ``fc0``.
    2. It passes through two ``SpectralConv2d + 1x1 Conv2d`` blocks, with a
       GELU after the first block only.
    3. It is projected point-wise by ``fc1 -> GELU -> fc2``.
    4. The output has shape ``(batch, nx, ny, 1)``.

    Args:
        modes1: Fourier modes kept along the first spatial axis.
        modes2: Fourier modes kept along the second spatial axis.
        width: hidden channel count.
        in_channels: size of the input's last axis (default 3).
    """

    def __init__(self, modes1, modes2, width, in_channels=3):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.fc0 = nn.Linear(in_channels, self.width)  # lift each point to `width` channels
        self.conv0 = SpectralConv2d(self.width, self.width, modes1, modes2)
        self.conv1 = SpectralConv2d(self.width, self.width, modes1, modes2)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError(
                f"expected input of shape (batch, nx, ny, channels), got {tuple(x.shape)}"
            )
        # nn.Linear acts on the last (channel) axis, so the lift needs no
        # flattening. Then move channels first for the conv layers:
        # (b, nx, ny, c) -> (b, width, nx, ny).
        x = self.fc0(x).permute(0, 3, 1, 2)

        x = nn.functional.gelu(self.conv0(x) + self.w0(x))
        x = self.conv1(x) + self.w1(x)

        # Back to channels-last for the point-wise projection:
        # (b, width, nx, ny) -> (b, nx, ny, 1).
        x = x.permute(0, 2, 3, 1)
        return self.fc2(nn.functional.gelu(self.fc1(x)))

In [9]:
def pde_residual(model, xyt, diffusivity=None):
    """Return the 2-D heat-equation residual ``T_t - a * (T_xx + T_yy)``.

    Args:
        model: differentiable callable evaluated on ``xyt``; presumably
            returns one temperature value per row (``(N, 1)``) — TODO confirm
            for non-PINN models.
        xyt: ``(N, 3)`` tensor of ``(x, y, t)`` collocation points. It is
            cloned and detached first, so the caller's tensor is not modified
            and no gradient flows back into it.
        diffusivity: thermal diffusivity ``a``. When ``None``, the
            module-level ``alpha`` is used.

    Returns:
        ``(N, 1)`` residual tensor. It keeps its autograd graph
        (``create_graph=True``), so it can be used directly in a training loss.
    """
    a = alpha if diffusivity is None else diffusivity
    xyt = xyt.clone().detach().requires_grad_(True)
    T = model(xyt)
    # grad_outputs=ones sums the gradient over output rows. That equals the
    # per-point derivative when each output row depends only on its own
    # input row, as with a point-wise MLP.
    grads = torch.autograd.grad(T, xyt, grad_outputs=torch.ones_like(T), create_graph=True)[0]
    T_x, T_y, T_t = grads[:, 0:1], grads[:, 1:2], grads[:, 2:3]
    T_xx = torch.autograd.grad(T_x, xyt, grad_outputs=torch.ones_like(T_x), create_graph=True)[0][:, 0:1]
    T_yy = torch.autograd.grad(T_y, xyt, grad_outputs=torch.ones_like(T_y), create_graph=True)[0][:, 1:2]
    return T_t - a * (T_xx + T_yy)

In [10]:
def generate_synthetic_field(nx=64, ny=64, samples=100, rng=None, bounds=None):
    """Build ``(input, target)`` pairs of random Gaussian-bump fields.

    Each input is the sum of 4 bumps ``exp(-50 * r^2)``, with centres drawn
    uniformly from ``[0, 1)^2``. The target is the input after 10 sweeps of
    4-neighbour averaging with periodic wrap-around (``np.roll``). This is a
    cheap placeholder for diffusive evolution, not a true PDE solve.

    Args:
        nx: number of grid points along x.
        ny: number of grid points along y.
        samples: number of pairs to generate; 0 yields empty arrays.
        rng: optional ``np.random.Generator`` for reproducible output. When
            ``None``, the legacy global ``np.random`` state is used.
        bounds: ``(lo, hi)`` extent of both axes; defaults to ``domain``.

    Returns:
        ``(data_in, data_out)``, each a float32 array of shape
        ``(samples, ny, nx, 1)``. The grid uses ``indexing='xy'``: rows
        follow y and columns follow x.
    """
    lo, hi = domain if bounds is None else bounds
    X, Y = np.meshgrid(np.linspace(lo, hi, nx), np.linspace(lo, hi, ny), indexing='xy')

    # Preallocate so that samples == 0 returns correctly shaped empty arrays
    # (np.stack would raise on an empty list).
    data_in = np.empty((samples, ny, nx, 1), dtype=np.float32)
    data_out = np.empty_like(data_in)

    for s in range(samples):
        centers = rng.random((4, 2)) if rng is not None else np.random.rand(4, 2)
        field0 = np.zeros_like(X)
        for cx, cy in centers:
            field0 += np.exp(-50 * ((X - cx) ** 2 + (Y - cy) ** 2))
        # Smoothing stands in for PDE evolution. np.roll returns new arrays,
        # so field0 is never mutated by this loop.
        target = field0
        for _ in range(10):
            target = (np.roll(target, 1, axis=0) + np.roll(target, -1, axis=0)
                      + np.roll(target, 1, axis=1) + np.roll(target, -1, axis=1)) / 4.0
        data_in[s, ..., 0] = field0
        data_out[s, ..., 0] = target
    return data_in, data_out

In [13]:
def train(epochs=2000, n_collocation=1000, log_every=100):
    """Train a PINN and an FNO side by side on the 2-D heat problem.

    Each epoch performs one optimisation step per model:

    * PINN: minimise the mean squared heat-equation residual (``pde_residual``)
      at ``n_collocation`` points drawn uniformly from
      ``domain x domain x [0, t_final)``. This step runs in float32, because
      second derivatives computed under fp16 autocast lose too much precision.
    * FNO: regress one randomly chosen sample from ``generate_synthetic_field``.
      The input channels are ``(X, Y, forcing)``. Autocast and GradScaler are
      enabled only when ``device`` is CUDA.

    NOTE(review): the PINN loss has no initial- or boundary-condition terms,
    so any constant ``T`` drives it to zero. Add IC/BC losses if a meaningful
    solution is required.

    Args:
        epochs: number of training iterations.
        n_collocation: number of PINN collocation points per iteration.
        log_every: print losses every ``log_every`` epochs (0 disables logging).

    Returns:
        ``(pinn, fno)``: the trained models.
    """
    # instantiate models
    pinn = PINN().to(device)
    fno = FNO2d(modes1=12, modes2=12, width=32).to(device)

    opt_pinn = optim.Adam(pinn.parameters(), lr=1e-3)
    opt_fno = optim.Adam(fno.parameters(), lr=1e-3)

    # Mixed precision only helps (and is only supported by torch.cuda.amp) on CUDA.
    use_amp = device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    # synthetic dataset
    nx = ny = 64
    X, Y = meshgrid2d(nx)
    data_in, data_out = generate_synthetic_field(nx=nx, ny=ny, samples=200)
    data_in_t = torch.from_numpy(data_in).to(device)  # (S, nx, ny, 1) for nx == ny
    data_out_t = torch.from_numpy(data_out).to(device)
    # X and Y are float64 numpy arrays. Convert them once to a float32 tensor
    # of shape (1, n, n, 2) so they can be concatenated with the forcing channel.
    coords = torch.from_numpy(np.stack([X, Y], axis=-1)).float().to(device)[None]

    lo, hi = domain[0], domain[1]
    for ep in range(epochs):
        # --- PINN step (full precision) ---
        opt_pinn.zero_grad()
        x = torch.rand(n_collocation, 1, device=device) * (hi - lo) + lo
        y = torch.rand(n_collocation, 1, device=device) * (hi - lo) + lo
        t = torch.rand(n_collocation, 1, device=device) * t_final
        res = pde_residual(pinn, torch.cat([x, y, t], dim=1))
        loss_pinn = torch.mean(res ** 2)
        loss_pinn.backward()
        opt_pinn.step()

        # --- FNO step (one random sample, AMP on CUDA) ---
        opt_fno.zero_grad()
        idx = np.random.randint(data_in_t.shape[0])
        fno_in = torch.cat([coords, data_in_t[idx:idx + 1]], dim=-1)
        with autocast(enabled=use_amp):
            fno_out = fno(fno_in)
            # Compute the loss in float32 even when the forward pass ran in fp16.
            loss_fno = torch.mean((fno_out.float() - data_out_t[idx:idx + 1]) ** 2)
        scaler.scale(loss_fno).backward()
        scaler.step(opt_fno)
        scaler.update()

        if log_every and ep % log_every == 0:
            print(f'Epoch {ep}, PINN Loss: {loss_pinn.item():.4f}, FNO Loss: {loss_fno.item():.4f}')

    return pinn, fno