# Fairness-Constrained Policy Optimization (FCPO)

**A minimal implementation of long-term group fairness in reinforcement learning (RL) using a primal–dual (Lagrangian) policy-gradient method**

In [None]:
!pip install torch numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
class Env:
    """Toy two-group ladder MDP for studying long-term fairness.

    States ``0 .. n_s-1`` form a ladder.  Each episode belongs to one group,
    which only changes the initial-state distribution ``init[g]``.  Action 1
    earns a reward that grows linearly with the current state (-0.5 at the
    bottom, +1.5 at the top) and moves one rung up with probability 0.7
    (otherwise one rung down).  Action 0 earns nothing and drops one rung with
    probability 0.5.  Observations are a one-hot state followed by a one-hot
    group, i.e. a vector of length ``n_s + n_g``.
    """

    def __init__(self):
        self.n_s, self.n_a, self.n_g, self.gamma = 5, 2, 2, 0.95
        # Initial-state distributions; group 1 starts higher on average.
        self.init = {0: [0.4, 0.3, 0.2, 0.1, 0.0], 1: [0.1, 0.2, 0.3, 0.3, 0.1]}

    def reset(self, g=None):
        """Start an episode for group ``g`` (uniformly random if None).

        Returns:
            The first observation.
        """
        self.g = np.random.randint(self.n_g) if g is None else g
        self.s = np.random.choice(self.n_s, p=self.init[self.g])
        return self._obs()

    def _obs(self):
        """Return one-hot(state) concatenated with one-hot(group)."""
        o = np.zeros(self.n_s + self.n_g)
        o[self.s] = 1
        o[self.n_s + self.g] = 1
        return o

    def step(self, a):
        """Apply action ``a``.

        Returns:
            ``(obs, reward, done, group)``; ``done`` is always False, so
            episode length is controlled by the caller.
        """
        top = self.n_s - 1
        # Reward is computed from the state *before* the transition.
        r = 2 * self.s / top - 0.5 if a else 0
        if a:
            self.s = min(self.s + 1, top) if np.random.rand() < 0.7 else max(self.s - 1, 0)
        elif np.random.rand() < 0.5:
            self.s = max(self.s - 1, 0)
        return self._obs(), r, False, self.g

env = Env(); print('Env ready')  # one shared environment instance for the agent

In [None]:
class Pi(nn.Module):
    """Categorical policy: an MLP mapping observations to action probabilities.

    Args:
        in_dim: observation size (default 7, the length of ``Env``'s one-hot
            observation).
        n_actions: number of discrete actions.
        hidden: width of the single tanh hidden layer.
    """

    def __init__(self, in_dim=7, n_actions=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Tanh(), nn.Linear(hidden, n_actions)
        )

    def forward(self, x):
        """Return softmax probabilities over the last (action) dimension."""
        return F.softmax(self.net(x), -1)

class V(nn.Module):
    """State-value critic: an MLP mapping observations to a scalar estimate.

    Args:
        in_dim: observation size (default 7, the length of ``Env``'s one-hot
            observation).
        hidden: width of the single tanh hidden layer.
    """

    def __init__(self, in_dim=7, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x):
        """Return value estimates with the trailing size-1 output axis removed."""
        # squeeze(-1) only drops the output axis; a bare squeeze() would also
        # collapse a batch containing a single sample into a 0-d tensor.
        return self.net(x).squeeze(-1)

In [None]:
class FCPO:
    """Fairness-Constrained Policy Optimization via a per-group Lagrangian.

    The policy is updated REINFORCE-style with the advantage ``R - V(s)``.
    For samples from group ``g`` this advantage is reduced by
    ``lam[g] * (R - V_g(s))``, where ``V_g`` is a group-specific critic.  The
    multipliers ``lam`` follow projected (non-negative) dual ascent on the
    constraint ``J_g - mean_g(J_g) <= eps / 2``, where ``J_g`` is the batch mean
    of group ``g``'s discounted reward-to-go.

    Args:
        env: environment providing ``reset()``, ``step(a)`` returning
            ``(obs, reward, done, group)``, ``gamma`` and optionally ``n_g``
            (number of groups, default 2).
        eps: tolerated disparity ``max_g J_g - min_g J_g``.
        horizon: number of steps per rollout episode.
        lr: Adam learning rate for the policy and all critics.
        dual_lr: step size of the dual-variable update.
    """

    def __init__(self, env, eps=0.15, horizon=100, lr=3e-3, dual_lr=0.01):
        self.env, self.eps, self.gam = env, eps, env.gamma
        self.horizon, self.dual_lr = horizon, dual_lr
        self.n_g = getattr(env, 'n_g', 2)
        self.device = device
        self.pi = Pi().to(self.device)
        self.v = V().to(self.device)
        # One critic per group; its advantage drives the fairness penalty.
        self.vg = nn.ModuleList([V().to(self.device) for _ in range(self.n_g)])
        self.opt_pi = torch.optim.Adam(self.pi.parameters(), lr)
        self.opt_v = torch.optim.Adam(self.v.parameters(), lr)
        self.opt_vg = torch.optim.Adam(self.vg.parameters(), lr)
        self.lam = torch.zeros(self.n_g, device=self.device)

    def collect(self, n=20):
        """Roll out ``n`` episodes of ``self.horizon`` steps with the current policy.

        Returns:
            Flat, time-ordered lists ``(obs, act, rew, grp)``; episode ``k``
            occupies indices ``[k * horizon, (k + 1) * horizon)``.
        """
        obs, act, rew, grp = [], [], [], []
        for _ in range(n):
            o = self.env.reset()
            for _ in range(self.horizon):
                with torch.no_grad():
                    p = self.pi(torch.as_tensor(o, dtype=torch.float32, device=self.device))
                    a = torch.multinomial(p, 1).item()
                o2, r, _, g = self.env.step(a)
                obs.append(o); act.append(a); rew.append(r); grp.append(g)
                o = o2
        return obs, act, rew, grp

    def rets(self, rews, ep_len=None):
        """Discounted reward-to-go for a flat list of rewards.

        Args:
            rews: rewards in time order.
            ep_len: if given, ``rews`` is treated as consecutive episodes of
                this length and the return is reset at each boundary, so no
                reward leaks into the preceding episode.  ``None`` treats the
                whole list as one trajectory.

        Returns:
            A float32 tensor of the same length as ``rews``.
        """
        G, gs = 0.0, [0.0] * len(rews)
        for t in reversed(range(len(rews))):
            if ep_len is not None and (t + 1) % ep_len == 0:
                G = 0.0  # t is the last step of an episode
            G = rews[t] + self.gam * G
            gs[t] = G
        return torch.tensor(gs, dtype=torch.float32, device=self.device)

    def _update_critics(self, O, R, G):
        """Take one regression step for the shared critic and each group critic."""
        v_loss = ((self.v(O) - R) ** 2).mean()
        self.opt_v.zero_grad(); v_loss.backward(); self.opt_v.step()

        # The group critics share one Adam; sum their (disjoint) losses and step
        # once.  Stepping once per group would, on PyTorch versions where
        # zero_grad() zero-fills instead of clearing, re-apply Adam momentum to
        # the critic whose gradient had just been zeroed.
        vg_losses = []
        for g in range(self.n_g):
            m = G == g
            if m.any():
                vg_losses.append(((self.vg[g](O[m]) - R[m]) ** 2).mean())
        if vg_losses:
            self.opt_vg.zero_grad()
            torch.stack(vg_losses).sum().backward()
            self.opt_vg.step()

    def _fair_advantages(self, O, R, G):
        """Return ``R - V(s)`` minus ``lam[g] * (R - V_g(s))`` per sample (no grad)."""
        with torch.no_grad():
            adv = R - self.v(O)
            fadj = torch.zeros_like(adv)
            for g in range(self.n_g):
                m = G == g
                if m.any():
                    fadj[m] = (R[m] - self.vg[g](O[m])) * self.lam[g]
            return adv - fadj

    def _update_policy(self, O, A, adv_f):
        """Take one policy-gradient step on the fairness-adjusted advantages."""
        probs = self.pi(O)
        # gather picks each sample's taken-action probability; 1e-10 guards log(0).
        log_p = torch.log(probs.gather(1, A.unsqueeze(1)).squeeze(1) + 1e-10)
        loss_pi = -(log_p * adv_f).mean()
        self.opt_pi.zero_grad(); loss_pi.backward(); self.opt_pi.step()

    def _update_duals(self, R, G):
        """Run projected dual ascent on the per-group fairness constraints.

        Returns:
            ``(disparity, group_returns)`` measured on this batch; a group
            absent from the batch is reported with return 0.
        """
        g_rets = [R[G == g].mean().item() if (G == g).any() else 0
                  for g in range(self.n_g)]
        disp = max(g_rets) - min(g_rets)
        mean_ret = float(np.mean(g_rets))
        for g in range(self.n_g):
            # Constraint J_g - mean <= eps / 2.  With two groups the deviations
            # are equal and opposite, so this is exactly max - min <= eps; with
            # more groups only over-performing groups are constrained.
            violation = g_rets[g] - mean_ret - self.eps / 2
            self.lam[g] = max(0.0, self.lam[g].item() + self.dual_lr * violation)
        return disp, g_rets

    def train(self):
        """Run one collect / critic / policy / dual iteration.

        Returns:
            ``{'d': disparity, 'gr': per-group returns, 'lam': multipliers}``.
        """
        obs, act, rew, grp = self.collect()
        O = torch.as_tensor(np.asarray(obs), dtype=torch.float32, device=self.device)
        A = torch.as_tensor(act, dtype=torch.long, device=self.device)
        G = torch.as_tensor(grp, dtype=torch.long, device=self.device)
        R = self.rets(rew, ep_len=self.horizon)

        self._update_critics(O, R, G)
        adv_f = self._fair_advantages(O, R, G)
        self._update_policy(O, A, adv_f)
        disp, g_rets = self._update_duals(R, G)

        return {'d': disp, 'gr': g_rets, 'lam': self.lam.cpu().tolist()}

In [None]:
agent = FCPO(env)
h = defaultdict(list)  # metric name -> list of per-iteration values

N_ITERS, LOG_EVERY = 100, 10
for it in range(N_ITERS):
    metrics = agent.train()
    for key, val in metrics.items():
        h[key].append(val)
    if it % LOG_EVERY == 0:
        groups = [f'{x:.2f}' for x in metrics['gr']]
        print(f"{it}: D={metrics['d']:.3f}, G={groups}")
print('Done')

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 6))
n_groups = len(h['gr'][0])

# Raw per-iteration disparity against the constraint threshold eps.
ax[0,0].plot(h['d']); ax[0,0].axhline(agent.eps, color='r', ls='--'); ax[0,0].set_title('Disparity')

for g in range(n_groups): ax[0,1].plot([x[g] for x in h['gr']], label=f'G{g}')
ax[0,1].legend(); ax[0,1].set_title('Returns')

# Greek lambda written directly (the source is UTF-8).
for g in range(n_groups): ax[1,0].plot([x[g] for x in h['lam']], label=f'λ{g}')
ax[1,0].legend(); ax[1,0].set_title('Dual Vars')

# Smoothed disparity trend instead of a second copy of the raw curve.
win = min(10, len(h['d']))
smooth = np.convolve(h['d'], np.ones(win) / win, mode='valid')
ax[1,1].plot(range(win - 1, len(h['d'])), smooth)
ax[1,1].axhline(agent.eps, color='r', ls='--')
ax[1,1].set_title(f'Disparity ({win}-iter moving avg)')
plt.tight_layout(); plt.show()

# Report whether the final disparity actually meets the constraint.
final = h['d'][-1]
ok = final <= agent.eps
print(f"Final: {final:.3f} {'<=' if ok else '>'} {agent.eps} ({'satisfied' if ok else 'violated'})")