@@ -1330,58 +1330,46 @@ def FDIV32ri_prec :
1330
1330
// FMA
1331
1331
//
1332
1332
1333
- multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred > {
1333
+ multiclass FMA<string OpcStr, RegTyInfo t, list<Predicate> Preds = [] > {
1334
1334
defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
1335
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1335
+ def rrr : NVPTXInst<(outs t. RC:$dst), (ins t. RC:$a, t. RC:$b, t. RC:$c),
1336
1336
asmstr,
1337
- [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
1338
- Requires<[Pred]>;
1339
- def rri : NVPTXInst<(outs RC:$dst),
1340
- (ins RC:$a, RC:$b, ImmCls:$c),
1341
- asmstr,
1342
- [(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
1343
- Requires<[Pred]>;
1344
- def rir : NVPTXInst<(outs RC:$dst),
1345
- (ins RC:$a, ImmCls:$b, RC:$c),
1346
- asmstr,
1347
- [(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
1348
- Requires<[Pred]>;
1349
- def rii : NVPTXInst<(outs RC:$dst),
1350
- (ins RC:$a, ImmCls:$b, ImmCls:$c),
1351
- asmstr,
1352
- [(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
1353
- Requires<[Pred]>;
1354
- def iir : NVPTXInst<(outs RC:$dst),
1355
- (ins ImmCls:$a, ImmCls:$b, RC:$c),
1356
- asmstr,
1357
- [(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
1358
- Requires<[Pred]>;
1359
-
1360
- }
1361
-
1362
- multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
1363
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1364
- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1365
- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
1366
- Requires<[useFP16Math, Pred]>;
1367
- }
1368
-
1369
- multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
1370
- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1371
- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1372
- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
1373
- Requires<[hasBF16Math, Pred]>;
1337
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
1338
+ Requires<Preds>;
1339
+
1340
+ if t.SupportsImm then {
1341
+ def rri : NVPTXInst<(outs t.RC:$dst),
1342
+ (ins t.RC:$a, t.RC:$b, t.Imm:$c),
1343
+ asmstr,
1344
+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
1345
+ Requires<Preds>;
1346
+ def rir : NVPTXInst<(outs t.RC:$dst),
1347
+ (ins t.RC:$a, t.Imm:$b, t.RC:$c),
1348
+ asmstr,
1349
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
1350
+ Requires<Preds>;
1351
+ def rii : NVPTXInst<(outs t.RC:$dst),
1352
+ (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
1353
+ asmstr,
1354
+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
1355
+ Requires<Preds>;
1356
+ def iir : NVPTXInst<(outs t.RC:$dst),
1357
+ (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
1358
+ asmstr,
1359
+ [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
1360
+ Requires<Preds>;
1361
+ }
1374
1362
}
1375
1363
1376
- defm FMA16_ftz : FMA_F16 <"fma.rn.ftz.f16", f16, Int16Regs , doF32FTZ>;
1377
- defm FMA16 : FMA_F16 <"fma.rn.f16", f16, Int16Regs, True >;
1378
- defm FMA16x2_ftz : FMA_F16 <"fma.rn.ftz.f16x2", v2f16, Int32Regs , doF32FTZ>;
1379
- defm FMA16x2 : FMA_F16 <"fma.rn.f16x2", v2f16, Int32Regs, True >;
1380
- defm BFMA16 : FMA_BF16 <"fma.rn.bf16", bf16, Int16Regs, True >;
1381
- defm BFMA16x2 : FMA_BF16 <"fma.rn.bf16x2", v2bf16, Int32Regs, True >;
1382
- defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
1383
- defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True >;
1384
- defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True >;
1364
+ defm FMA16_ftz : FMA <"fma.rn.ftz.f16", F16RT, [useFP16Math , doF32FTZ] >;
1365
+ defm FMA16 : FMA <"fma.rn.f16", F16RT, [useFP16Math] >;
1366
+ defm FMA16x2_ftz : FMA <"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math , doF32FTZ] >;
1367
+ defm FMA16x2 : FMA <"fma.rn.f16x2", F16X2RT, [useFP16Math] >;
1368
+ defm BFMA16 : FMA <"fma.rn.bf16", BF16RT, [hasBF16Math] >;
1369
+ defm BFMA16x2 : FMA <"fma.rn.bf16x2", BF16X2RT, [hasBF16Math] >;
1370
+ defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [ doF32FTZ] >;
1371
+ defm FMA32 : FMA<"fma.rn.f32", F32RT >;
1372
+ defm FMA64 : FMA<"fma.rn.f64", F64RT >;
1385
1373
1386
1374
// sin/cos
1387
1375
@@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
1999
1987
Requires<[doF32FTZ]>;
2000
1988
def : Pat<(i1 (OpNode f32:$a, f32:$b)),
2001
1989
(SETP_f32rr $a, $b, Mode)>;
2002
- def : Pat<(i1 (OpNode Float32Regs :$a, fpimm:$b)),
1990
+ def : Pat<(i1 (OpNode f32 :$a, fpimm:$b)),
2003
1991
(SETP_f32ri $a, fpimm:$b, ModeFTZ)>,
2004
1992
Requires<[doF32FTZ]>;
2005
1993
def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
@@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
2056
2044
def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
2057
2045
def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
2058
2046
def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
2059
- def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0 >]>;
2047
+ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32 >]>;
2060
2048
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
2061
2049
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
2062
2050
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
@@ -2352,42 +2340,10 @@ def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
2352
2340
def CallArgEndInst0 : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>;
2353
2341
def RETURNInst : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>;
2354
2342
2355
- class CallArgInst<NVPTXRegClass regclass> :
2356
- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
2357
- [(CallArg (i32 0), regclass:$a)]>;
2358
-
2359
- class CallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
2360
- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
2361
- [(CallArg (i32 0), vt:$a)]>;
2362
-
2363
- class LastCallArgInst<NVPTXRegClass regclass> :
2364
- NVPTXInst<(outs), (ins regclass:$a), "$a",
2365
- [(LastCallArg (i32 0), regclass:$a)]>;
2366
- class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
2367
- NVPTXInst<(outs), (ins regclass:$a), "$a",
2368
- [(LastCallArg (i32 0), vt:$a)]>;
2369
-
2370
- def CallArgI64 : CallArgInst<Int64Regs>;
2371
- def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
2372
- def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
2373
- def CallArgF64 : CallArgInst<Float64Regs>;
2374
- def CallArgF32 : CallArgInst<Float32Regs>;
2375
-
2376
- def LastCallArgI64 : LastCallArgInst<Int64Regs>;
2377
- def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
2378
- def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
2379
- def LastCallArgF64 : LastCallArgInst<Float64Regs>;
2380
- def LastCallArgF32 : LastCallArgInst<Float32Regs>;
2381
-
2382
- def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ",
2383
- [(CallArg (i32 0), (i32 imm:$a))]>;
2384
- def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a",
2385
- [(LastCallArg (i32 0), (i32 imm:$a))]>;
2386
-
2387
2343
def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
2388
- [(CallArg (i32 1), (i32 imm:$a) )]>;
2344
+ [(CallArg 1, imm:$a)]>;
2389
2345
def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
2390
- [(LastCallArg (i32 1), (i32 imm:$a) )]>;
2346
+ [(LastCallArg 1, imm:$a)]>;
2391
2347
2392
2348
def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
2393
2349
[(CallVoid (Wrapper tglobaladdr:$addr))]>;
0 commit comments